From 399ab88fdc244c444e4a0d631ef7200bbcebb8b0 Mon Sep 17 00:00:00 2001
From: Sting <loic.allegre@ensiie.fr>
Date: Thu, 13 Jul 2023 13:56:17 +0200
Subject: [PATCH] Try with modified Kullback-Leibler

---
 cnn_madmom/segment.py | 45 ++++++++++++++++++++++++-------------------
 1 file changed, 25 insertions(+), 20 deletions(-)

diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py
index 40034d3..a86921d 100644
--- a/cnn_madmom/segment.py
+++ b/cnn_madmom/segment.py
@@ -7,7 +7,8 @@ from scipy.ndimage.filters import maximum_filter
 
 def segment(songfile):
 
-    delay = 0
+    delay = -4
+    backtrack = False
 
     cnn = madmom.features.onsets.CNNOnsetProcessor()
     spectral = madmom.features.onsets.SpectralOnsetProcessor('complex_domain')
@@ -21,23 +22,24 @@ def segment(songfile):
     spectral_function = spectral(songfile, num_channels=1)
     spectral_function = spectral_function/(spectral_function.max())
     
-    activation_function = 0.5*cnn_function + 0.5*spectral_function
-    #activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function)
+    #activation_function = 0.5*cnn_function + 0.5*spectral_function
+    activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function)
     #activation_function = np.where(spectral_function > 0.14, cnn_function, 0)
     #onsets = proc(activation_function)
     
     activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
     cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=1, smooth=0)
+    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.6, smooth=0)
     onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2])
 
 
-    # Backtrack onsets to closest earlier local minimum
-    backtrack_max_frames = 50
-    for i in range(len(onsets)):
-        initial_onset = onsets[i]
-        while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
-            onsets[i] -= 1
+    if backtrack:
+        # Backtrack onsets to closest earlier local minimum
+        backtrack_max_frames = 50
+        for i in range(len(onsets)):
+            initial_onset = onsets[i]
+            while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
+                onsets[i] -= 1
 
     onsets = (onsets + delay)/100
 
@@ -50,8 +52,10 @@ def segment(songfile):
 if __name__ == "__main__":
     songfile = sys.argv[1]
 
+    backtrack = False
+
     cnn = madmom.features.onsets.CNNOnsetProcessor()
-    spectral = madmom.features.onsets.SpectralOnsetProcessor('complex_domain')
+    spectral = madmom.features.onsets.SpectralOnsetProcessor('modified_kullback_leibler')
 
 
     spec = spec = madmom.audio.spectrogram.Spectrogram(songfile, num_channels=1)
@@ -62,22 +66,23 @@ if __name__ == "__main__":
     spectral_function = spectral(songfile, num_channels=1)
     spectral_function = spectral_function/(spectral_function.max())
     
-    activation_function = 0.5*cnn_function + 0.5*spectral_function
-    #activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function)
+    #activation_function = 0.5*cnn_function + 0.5*spectral_function
+    activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function)
     #activation_function = np.where(spectral_function > 0.14, cnn_function, 0)
     #onsets = proc(activation_function)
     
     activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
     cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=1, smooth=0)
+    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.6, smooth=0)
     onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2])
     
     # Backtrack onsets to closest earlier local minimum
-    backtrack_max_frames = 50
-    for i in range(len(onsets)):
-        initial_onset = onsets[i]
-        while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
-            onsets[i] -= 1
+    if backtrack:
+        backtrack_max_frames = 50
+        for i in range(len(onsets)):
+            initial_onset = onsets[i]
+            while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
+                onsets[i] -= 1
 
     print(onsets/100)
 
@@ -85,7 +90,7 @@ if __name__ == "__main__":
     axs[0].imshow(log_spec.T, origin='lower', aspect='auto')
     axs[1].plot(cnn_smoothed)
     axs[1].plot(spectral_function, color='green')
-    axs[1].plot(activation_smoothed, color='pink')
+    axs[1].plot(activation_smoothed, color='orange')
     axs[1].vlines(onsets, 0, 1, colors='red')
 
     plt.show()
-- 
GitLab