From 8a8846539e1f5374c9757f7d5d8a4d9d82066a26 Mon Sep 17 00:00:00 2001
From: Sting <loic.allegre@ensiie.fr>
Date: Tue, 18 Jul 2023 12:53:05 +0200
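Subject: [PATCH] Filter onsets with a pitch-weighted magnitude mask

Replace the magnitude-only onset validity check (window [2,8],
threshold 1.2) with a mask built from magnitude multiplied by the
parselmouth pitch track, normalised by its maximum (window [1,6],
threshold 0.15). The pitch values are zero-padded to the length of
the magnitude array before the product is taken.

Also drop the scipy.ndimage.filters maximum_filter import.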

---
 cnn_madmom/segment.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py
index 54257eb..2f8ac12 100644
--- a/cnn_madmom/segment.py
+++ b/cnn_madmom/segment.py
@@ -3,7 +3,6 @@ import numpy as np
 import sys
 import re
 import matplotlib.pyplot as plt
-from scipy.ndimage.filters import maximum_filter
 import scipy.signal as sg
 import parselmouth
 
@@ -17,6 +16,7 @@ def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500):
     spectral = madmom.features.onsets.SpectralOnsetProcessor('modified_kullback_leibler')
 
     sig = madmom.audio.signal.Signal(songfile, num_channels=1)
+    parsel = parselmouth.Sound(sig)
 
     spec = madmom.audio.spectrogram.Spectrogram(sig)
     filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24)
@@ -38,11 +38,22 @@ def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500):
     onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.1, smooth=0)
     #onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
 
+    pitch = parsel.to_pitch()
+    pitch_values = pitch.selected_array['frequency']
+
+    pad_before = round(pitch.xs()[0]*100)
+    pad_after = len(magnitude) - len(pitch_values) - pad_before
+
+    pitch_values = np.pad(pitch_values, (pad_before, pad_after), 'constant', constant_values=(0,0))
+
+    mask_function = magnitude * pitch_values
+    mask_function = mask_function/np.max(mask_function)
+    mask_threshold = 0.15
+    mask_window = [1,6]
     invalid_onsets_idx = []
-    magnitude_window = [2,8]
-    magnitude_threshold = 1.2
+    
     for i in range(len(onsets)):
-        if np.max(magnitude[onsets[i]+magnitude_window[0]:onsets[i]+magnitude_window[1]]) < magnitude_threshold:
+        if np.max(mask_function[onsets[i]+mask_window[0]:onsets[i]+mask_window[1]]) < mask_threshold:
             invalid_onsets_idx.append(i)
     
     onsets = np.delete(onsets, invalid_onsets_idx)
-- 
GitLab