From ede80384549b7a828eed44b2486b4f31807244e4 Mon Sep 17 00:00:00 2001
From: Sting <loic.allegre@ensiie.fr>
Date: Mon, 17 Jul 2023 13:47:03 +0200
Subject: [PATCH] Use reference file for syl names and line timings

---
 assUtils.py           |  52 ++++++++++++---------
 autokara.py           |  13 ++++--
 cnn_madmom/segment.py | 105 ++++++++++++++++++++++++++++++++++++++----
 3 files changed, 134 insertions(+), 36 deletions(-)

diff --git a/assUtils.py b/assUtils.py
index e14f20a..7b5fd28 100644
--- a/assUtils.py
+++ b/assUtils.py
@@ -1,5 +1,6 @@
 import numpy as np
 import math
+import re
 
 
 def timeToDate(time):
@@ -21,6 +22,27 @@ def dateToTime(date):
     return int(date[8:10]) + secondsInCentiseconds
 
 
+
+def getSyls(ass_file):
+    SYLS = []
+    with open(ass_file, 'r') as f:
+        CONTENT = f.read()
+        LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n");
+        RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
+        for line in LINES_KARA.findall(CONTENT):
+            syl_line = []
+            lastTime = dateToTime(line[0])
+            for couple in RGX_TAGS.findall(line[2]):
+                syl_line.append([lastTime, couple[1], int(couple[0])])
+                lastTime += int(couple[0])
+            syl_line.append([lastTime, '', 0])
+            SYLS.append(syl_line)
+    return SYLS
+
+
+
+
+
 class AssWriter:
 
     def __init__(self):
@@ -68,30 +90,16 @@ Comment: 0,0:00:05.68,0:00:05.68,Default,,0,0,0,,
 '''
         self.file.write(header)
 
-    def writeSyls(self, syl_timings, syls_per_line=10000):
+    def writeSyls(self, syl_timings):
         bottom = False
-        last_syl_dur = 500
-        syl_index = 0
-        while syl_index < (len(syl_timings) - syls_per_line):
-            start_time = timeToDate(syl_timings[syl_index][0])
-            end_time = timeToDate(syl_timings[syl_index + syls_per_line][0])
+        for syl_line in syl_timings:
+            start_time = timeToDate(syl_line[0][0])
+            end_time = timeToDate(syl_line[-1][0])
             v_margin = (150 if bottom else 50)
             line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,{v_margin:d},,'
-            for i in range(syl_index, syl_index + syls_per_line):
-                syl_dur = round((syl_timings[i+1][0] - syl_timings[i][0]) * 100)
-                line += f'{{\k{syl_dur:d}}}{syl_timings[i][1]:s}'
+            for i in range(len(syl_line) - 1):
+                syl_dur = round((syl_line[i+1][0] - syl_line[i][0]) * 100)
+                line += f'{{\k{syl_dur:d}}}{syl_line[i][1]:s}'
             line += '\n'
             self.file.write(line)
-            syl_index += syls_per_line
-            bottom = not bottom
-
-        start_time = timeToDate(syl_timings[syl_index][0])
-        end_time = timeToDate(syl_timings[-1][0] + last_syl_dur//100)
-        v_margin = (150 if bottom else 50)
-        line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,{v_margin:d},,'
-        for i in range(syl_index, len(syl_timings) - 1):
-            syl_dur = round((syl_timings[i+1][0] - syl_timings[i][0]) * 100)
-            line += f'{{\k{syl_dur:d}}}{syl_timings[i][1]:s}'
-        line += f'{{\k{last_syl_dur:d}}}{syl_timings[-1][1]:s}\n'
-
-        self.file.write(line)
\ No newline at end of file
+            bottom = not bottom
\ No newline at end of file
diff --git a/autokara.py b/autokara.py
index e7f69b1..e1d294a 100644
--- a/autokara.py
+++ b/autokara.py
@@ -4,7 +4,7 @@ import demucs.separate
 import subprocess
 import shlex
 from pathlib import Path
-from assUtils import AssWriter
+from assUtils import AssWriter, getSyls
 
 from cnn_madmom.segment import segment
 
@@ -13,6 +13,7 @@ parser = argparse.ArgumentParser(description='AutoKara - Automatic karaoke timin
 parser.add_argument("source_file", type=str, help="The video/audio file to time")
 parser.add_argument("ass_file", type=str, help="The ASS file in which to output the karaoke")
 parser.add_argument("--vocals", action="store_true", help="Treat the input as vocals file, i.e. do not perform vocals extraction")
+parser.add_argument("--ref", help="Use an ASS file as reference")
 
 args = parser.parse_args()
 
@@ -39,15 +40,19 @@ if not args.vocals :
 else:
     vocals_file = args.source_file
 
+if args.ref:
+    reference_syls = getSyls(args.ref)
+else:
+    reference_syls = None
 
 print("Identifying syl starts...")
-onsets = segment(vocals_file)
-syls = [[t, 'la'] for t in onsets]
+syls = segment(vocals_file, reference_syls=reference_syls)
+print(syls)
 
 print("Syls found, writing ASS file...")
 writer = AssWriter()
 writer.openAss(ass_file)
 writer.writeHeader()
-writer.writeSyls(syls, syls_per_line=10)
+writer.writeSyls(syls)
 writer.closeAss()
 
diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py
index 174424f..5e78190 100644
--- a/cnn_madmom/segment.py
+++ b/cnn_madmom/segment.py
@@ -1,12 +1,13 @@
 import madmom
 import numpy as np
 import sys
+import re
 import matplotlib.pyplot as plt
 from scipy.ndimage.filters import maximum_filter
 import scipy.signal as sg
 
 
-def segment(songfile):
+def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500):
 
     delay = -4
     backtrack = False
@@ -33,12 +34,12 @@ def segment(songfile):
     
     activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
     cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.5, smooth=0)
-    onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2])
+    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.1, smooth=0)
+    onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
 
     invalid_onsets_idx = []
     magnitude_window = [2,8]
-    magnitude_threshold = 0.8
+    magnitude_threshold = 1.2
     for i in range(len(onsets)):
         if np.max(magnitude[onsets[i]+magnitude_window[0]:onsets[i]+magnitude_window[1]]) < magnitude_threshold:
             invalid_onsets_idx.append(i)
@@ -46,6 +47,19 @@ def segment(songfile):
     onsets = np.delete(onsets, invalid_onsets_idx)
 
 
+    if reference_syls:
+        filtered_onsets = []
+        for line in reference_syls:
+            syl_number = len(line) - 1
+            line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
+            line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
+            if syl_number > len(line_onsets):
+                print("WARNING : failed to detect enough onsets in line")
+            filtered_onsets += line_onsets[0:syl_number]
+        
+        onsets = np.array(sorted(filtered_onsets))
+
+
     if backtrack:
         # Backtrack onsets to closest earlier local minimum
         backtrack_max_frames = 50
@@ -55,15 +69,72 @@ def segment(songfile):
                 onsets[i] -= 1
 
     onsets = (onsets + delay)/100
+    #print(onsets)
+
+    if reference_syls:
+        syls = []
+        onset_index = 0
+        for line in reference_syls:
+            #print(onset_index, " : ", line)
+            l = [[onsets[onset_index + i], line[i][1]] for i in range(len(line)-1)]
+            l.append([line[-1][0]/100, ''])
+            syls.append(l)
+            onset_index += (len(line) - 1)
+    else:
+        syls = []
+        onset_index = 0
+        for l in range(0, len(onsets), syls_per_line):
+            if onset_index + syls_per_line < len(onsets):
+                line = [[onset, 'la'] for onset in onsets[onset_index:onset_index+syls_per_line]]
+            else:
+                line = [[onset, 'la'] for onset in onsets[onset_index:]]
+            if onset_index + syls_per_line + 1 < len(onsets):
+                line.append([onsets[onset_index+syls_per_line+1], ''])
+            else:
+                line.append([onsets[-1] + last_syl_dur/100, ''])
+            syls.append(line)
+            onset_index += syls_per_line
+
+    return syls
+
 
-    print(onsets)
 
-    return onsets
+if __name__ == "__main__":
+
+    def dateToTime(date):
+        """
+        The `date` should be in the following format: H:MM:SS.cs
+        """
+        hourInMinuts = int(date[0:1]) * 60
+        minutsInSeconds = (int(date[2:4]) + hourInMinuts) * 60
+        secondsInCentiseconds = (int(date[5:7]) + minutsInSeconds) * 100
+        return int(date[8:10]) + secondsInCentiseconds
+
+
+    def getSyls(ass_file):
+        SYLS = []
+        with open(ass_file, 'r') as f:
+            CONTENT = f.read()
+            LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n");
+            RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
+            for line in LINES_KARA.findall(CONTENT):
+                syl_line = []
+                lastTime = dateToTime(line[0])
+                for couple in RGX_TAGS.findall(line[2]):
+                    syl_line.append((lastTime, couple[1], int(couple[0])))
+                    lastTime += int(couple[0])
+                syl_line.append([lastTime, '', 0])
+                SYLS.append(syl_line)
+        return SYLS
+
 
 
 
-if __name__ == "__main__":
     songfile = sys.argv[1]
+    if(len(sys.argv) == 3):
+        reference_syls = getSyls(sys.argv[2])
+    
+    print(reference_syls)
 
     backtrack = False
 
@@ -89,17 +160,26 @@ if __name__ == "__main__":
     
     activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
     cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.5, smooth=0)
+    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.25, smooth=0)
     onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2])
 
     invalid_onsets_idx = []
     magnitude_window = [2,8]
-    magnitude_threshold = 0.8
+    magnitude_threshold = 1.2
     for i in range(len(onsets)):
         if np.max(magnitude[onsets[i]+magnitude_window[0]:onsets[i]+magnitude_window[1]]) < magnitude_threshold:
             invalid_onsets_idx.append(i)
     
     onsets = np.delete(onsets, invalid_onsets_idx)
+
+    filtered_onsets = []
+    for line in reference_syls:
+        syl_number = len(line) - 1
+        line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
+        line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
+        filtered_onsets += line_onsets[0:syl_number]
+    
+    onsets = np.array(sorted(filtered_onsets))
     
     # Backtrack onsets to closest earlier local minimum
     if backtrack:
@@ -111,12 +191,17 @@ if __name__ == "__main__":
 
     print(onsets/100)
 
+    reference_onsets = [syl[0]+8 for line in reference_syls for syl in line[:-1]]
+
     fig, axs = plt.subplots(nrows=2, sharex=True)
     axs[0].imshow(log_spec.T, origin='lower', aspect='auto')
+    axs[0].vlines(reference_onsets, 0, 140, colors='red')
     axs[1].plot(magnitude)
     #axs[1].plot(cnn_smoothed)
     #axs[1].plot(spectral_function, color='green')
     axs[1].plot(activation_smoothed, color='orange')
-    axs[1].vlines(onsets, 0, 1, colors='red')
+    axs[1].vlines(onsets, 0, 2, colors='red')
+    axs[1].hlines([max(magnitude_threshold, 0.5)], 0, onsets[-1]+100, colors='black')
+
 
     plt.show()
-- 
GitLab