From 8e51f6dbf817003026a3096538bca8cdb8e3d479 Mon Sep 17 00:00:00 2001
From: Sting <lallegre26@gmail.com>
Date: Mon, 17 Jul 2023 20:51:02 +0200
Subject: [PATCH] Disable CNN threshold filter on onsets + report line index and bounds when too few onsets are detected

---
 cnn_madmom/segment.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py
index b1455c7..9572551 100644
--- a/cnn_madmom/segment.py
+++ b/cnn_madmom/segment.py
@@ -35,7 +35,7 @@ def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500):
     activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
     cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
     onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.1, smooth=0)
-    onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
+    #onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
 
     invalid_onsets_idx = []
     magnitude_window = [2,8]
@@ -49,12 +49,14 @@ def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500):
 
     if reference_syls:
         filtered_onsets = []
+        line_index = 0
         for line in reference_syls:
+            line_index += 1
             syl_number = len(line) - 1
             line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
             line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
             if syl_number > len(line_onsets):
-                print("WARNING : failed to detect enough onsets in line")
+                print("WARNING : failed to detect enough onsets in line %d (%d, %d)" % (line_index, line[0][0], line[-1][0]))
             filtered_onsets += line_onsets[0:syl_number]
         
         onsets = np.array(sorted(filtered_onsets))
@@ -161,7 +163,7 @@ if __name__ == "__main__":
     activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
     cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
     onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.1, smooth=0)
-    onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
+    #onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
 
     invalid_onsets_idx = []
     magnitude_window = [2,8]
@@ -175,12 +177,14 @@ if __name__ == "__main__":
 
     if reference_syls:
         filtered_onsets = []
+        line_index = 0
         for line in reference_syls:
+            line_index += 1
             syl_number = len(line) - 1
             line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
             line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
             if syl_number > len(line_onsets):
-                print("WARNING : failed to detect enough onsets in line")
+                print("WARNING : failed to detect enough onsets in line %d (%d, %d)" % (line_index, line[0][0], line[-1][0]))
             filtered_onsets += line_onsets[0:syl_number]
         
         onsets = np.array(sorted(filtered_onsets))
-- 
GitLab