From f37247f8b706f1855c9410817980a76437e7de86 Mon Sep 17 00:00:00 2001
From: Sting <loic.allegre@ensiie.fr>
Date: Mon, 17 Jul 2023 15:02:51 +0200
Subject: [PATCH] Add magnitude histogram in test mode

---
 cnn_madmom/segment.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py
index 5e78190..b1455c7 100644
--- a/cnn_madmom/segment.py
+++ b/cnn_madmom/segment.py
@@ -160,8 +160,8 @@ if __name__ == "__main__":
     
     activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
     cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.25, smooth=0)
-    onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2])
+    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.1, smooth=0)
+    onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
 
     invalid_onsets_idx = []
     magnitude_window = [2,8]
@@ -172,14 +172,18 @@ if __name__ == "__main__":
     
     onsets = np.delete(onsets, invalid_onsets_idx)
 
-    filtered_onsets = []
-    for line in reference_syls:
-        syl_number = len(line) - 1
-        line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
-        line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
-        filtered_onsets += line_onsets[0:syl_number]
-    
-    onsets = np.array(sorted(filtered_onsets))
+
+    if reference_syls:
+        filtered_onsets = []
+        for line in reference_syls:
+            syl_number = len(line) - 1
+            line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
+            line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
+            if syl_number > len(line_onsets):
+                print("WARNING : failed to detect enough onsets in line")
+            filtered_onsets += line_onsets[0:syl_number]
+        
+        onsets = np.array(sorted(filtered_onsets))
     
     # Backtrack onsets to closest earlier local minimum
     if backtrack:
@@ -203,5 +207,8 @@ if __name__ == "__main__":
     axs[1].vlines(onsets, 0, 2, colors='red')
     axs[1].hlines([max(magnitude_threshold, 0.5)], 0, onsets[-1]+100, colors='black')
 
+    bins = np.arange(0, 2, 0.05)
+    hist, hist_axs = plt.subplots(nrows=1)
+    hist_axs.hist(magnitude, bins=bins)
 
     plt.show()
-- 
GitLab