diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py
index 3a4a75ac47aca510ee63e33e339bfb1a1c307dc4..174424ff29dff56207a56d93d916ff2b241c00ba 100644
--- a/cnn_madmom/segment.py
+++ b/cnn_madmom/segment.py
@@ -20,6 +20,8 @@ def segment(songfile):
     filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24)
     log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1)
 
+    magnitude = np.max(log_spec[:,:100], axis=1)
+
     cnn_function = cnn(sig)
     spectral_function = spectral(sig)
     spectral_function = spectral_function/(spectral_function.max())
@@ -31,9 +33,18 @@ def segment(songfile):
     
     activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
     cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.6, smooth=0)
+    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.5, smooth=0)
     onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2])
 
+    invalid_onsets_idx = []
+    magnitude_window = [2,8]
+    magnitude_threshold = 0.8
+    for i in range(len(onsets)):
+        if np.max(magnitude[onsets[i]+magnitude_window[0]:onsets[i]+magnitude_window[1]]) < magnitude_threshold:
+            invalid_onsets_idx.append(i)
+    
+    onsets = np.delete(onsets, invalid_onsets_idx)
+
 
     if backtrack:
         # Backtrack onsets to closest earlier local minimum
@@ -65,6 +76,8 @@ if __name__ == "__main__":
     filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24)
     log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1)
 
+    magnitude = np.max(log_spec[:,:100], axis=1)
+
     cnn_function = cnn(sig)
     spectral_function = spectral(sig)
     spectral_function = spectral_function/(spectral_function.max())
@@ -76,8 +89,17 @@ if __name__ == "__main__":
     
     activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
     cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.6, smooth=0)
+    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.5, smooth=0)
     onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2])
+
+    invalid_onsets_idx = []
+    magnitude_window = [2,8]
+    magnitude_threshold = 0.8
+    for i in range(len(onsets)):
+        if np.max(magnitude[onsets[i]+magnitude_window[0]:onsets[i]+magnitude_window[1]]) < magnitude_threshold:
+            invalid_onsets_idx.append(i)
+    
+    onsets = np.delete(onsets, invalid_onsets_idx)
     
     # Backtrack onsets to closest earlier local minimum
     if backtrack:
@@ -91,8 +113,9 @@ if __name__ == "__main__":
 
     fig, axs = plt.subplots(nrows=2, sharex=True)
     axs[0].imshow(log_spec.T, origin='lower', aspect='auto')
-    axs[1].plot(cnn_smoothed)
-    axs[1].plot(spectral_function, color='green')
+    axs[1].plot(magnitude)
+    #axs[1].plot(cnn_smoothed)
+    #axs[1].plot(spectral_function, color='green')
     axs[1].plot(activation_smoothed, color='orange')
     axs[1].vlines(onsets, 0, 1, colors='red')