diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py index 5e781905f378a5375f23ef31d6938e1b1e549856..b1455c725c49ba52343f77b39094dfea5555f11f 100644 --- a/cnn_madmom/segment.py +++ b/cnn_madmom/segment.py @@ -160,8 +160,8 @@ if __name__ == "__main__": activation_smoothed = madmom.audio.signal.smooth(activation_function, 20) cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20) - onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.25, smooth=0) - onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2]) + onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.1, smooth=0) + onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1]) invalid_onsets_idx = [] magnitude_window = [2,8] @@ -172,14 +172,18 @@ if __name__ == "__main__": onsets = np.delete(onsets, invalid_onsets_idx) - filtered_onsets = [] - for line in reference_syls: - syl_number = len(line) - 1 - line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])] - line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x])) - filtered_onsets += line_onsets[0:syl_number] - - onsets = np.array(sorted(filtered_onsets)) + + if reference_syls: + filtered_onsets = [] + for line in reference_syls: + syl_number = len(line) - 1 + line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])] + line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x])) + if syl_number > len(line_onsets): + print("WARNING : failed to detect enough onsets in line") + filtered_onsets += line_onsets[0:syl_number] + + onsets = np.array(sorted(filtered_onsets)) # Backtrack onsets to closest earlier local minimum if backtrack: @@ -203,5 +207,8 @@ if __name__ == "__main__": axs[1].vlines(onsets, 0, 2, colors='red') axs[1].hlines([max(magnitude_threshold, 0.5)], 0, onsets[-1]+100, colors='black') + bins = np.arange(0, 2, 0.05) + hist, hist_axs = plt.subplots(nrows=1) + hist_axs.hist(magnitude, bins=bins) plt.show()