diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py index 3a4a75ac47aca510ee63e33e339bfb1a1c307dc4..174424ff29dff56207a56d93d916ff2b241c00ba 100644 --- a/cnn_madmom/segment.py +++ b/cnn_madmom/segment.py @@ -20,6 +20,8 @@ def segment(songfile): filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24) log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1) + magnitude = np.max(log_spec[:,:100], axis=1) + cnn_function = cnn(sig) spectral_function = spectral(sig) spectral_function = spectral_function/(spectral_function.max()) @@ -31,9 +33,18 @@ def segment(songfile): activation_smoothed = madmom.audio.signal.smooth(activation_function, 20) cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20) - onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.6, smooth=0) + onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.5, smooth=0) onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2]) + invalid_onsets_idx = [] + magnitude_window = [2,8] + magnitude_threshold = 0.8 + for i in range(len(onsets)): + if np.max(magnitude[onsets[i]+magnitude_window[0]:onsets[i]+magnitude_window[1]]) < magnitude_threshold: + invalid_onsets_idx.append(i) + + onsets = np.delete(onsets, invalid_onsets_idx) + if backtrack: # Backtrack onsets to closest earlier local minimum @@ -65,6 +76,8 @@ if __name__ == "__main__": filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24) log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1) + magnitude = np.max(log_spec[:,:100], axis=1) + cnn_function = cnn(sig) spectral_function = spectral(sig) spectral_function = spectral_function/(spectral_function.max()) @@ -76,8 +89,17 @@ if __name__ == "__main__": activation_smoothed = madmom.audio.signal.smooth(activation_function, 20) cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20) - onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.6, smooth=0) + onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.5, smooth=0) onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2]) + + invalid_onsets_idx = [] + magnitude_window = [2,8] + magnitude_threshold = 0.8 + for i in range(len(onsets)): + if np.max(magnitude[onsets[i]+magnitude_window[0]:onsets[i]+magnitude_window[1]]) < magnitude_threshold: + invalid_onsets_idx.append(i) + + onsets = np.delete(onsets, invalid_onsets_idx) # Backtrack onsets to closest earlier local minimum if backtrack: @@ -91,8 +113,9 @@ if __name__ == "__main__": fig, axs = plt.subplots(nrows=2, sharex=True) axs[0].imshow(log_spec.T, origin='lower', aspect='auto') - axs[1].plot(cnn_smoothed) - axs[1].plot(spectral_function, color='green') + axs[1].plot(magnitude) + #axs[1].plot(cnn_smoothed) + #axs[1].plot(spectral_function, color='green') axs[1].plot(activation_smoothed, color='orange') axs[1].vlines(onsets, 0, 1, colors='red')