diff --git a/assUtils.py b/assUtils.py index e14f20abfb682670d31d3cb371150d694e575be9..7b5fd284d7923cd7537cf24aad9062dc6f34df30 100644 --- a/assUtils.py +++ b/assUtils.py @@ -1,5 +1,6 @@ import numpy as np import math +import re def timeToDate(time): @@ -21,6 +22,27 @@ def dateToTime(date): return int(date[8:10]) + secondsInCentiseconds + +def getSyls(ass_file): + SYLS = [] + with open(ass_file, 'r') as f: + CONTENT = f.read() + LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n"); + RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)") + for line in LINES_KARA.findall(CONTENT): + syl_line = [] + lastTime = dateToTime(line[0]) + for couple in RGX_TAGS.findall(line[2]): + syl_line.append([lastTime, couple[1], int(couple[0])]) + lastTime += int(couple[0]) + syl_line.append([lastTime, '', 0]) + SYLS.append(syl_line) + return SYLS + + + + + class AssWriter: def __init__(self): @@ -68,30 +90,16 @@ Comment: 0,0:00:05.68,0:00:05.68,Default,,0,0,0,, ''' self.file.write(header) - def writeSyls(self, syl_timings, syls_per_line=10000): + def writeSyls(self, syl_timings): bottom = False - last_syl_dur = 500 - syl_index = 0 - while syl_index < (len(syl_timings) - syls_per_line): - start_time = timeToDate(syl_timings[syl_index][0]) - end_time = timeToDate(syl_timings[syl_index + syls_per_line][0]) + for syl_line in syl_timings: + start_time = timeToDate(syl_line[0][0]) + end_time = timeToDate(syl_line[-1][0]) v_margin = (150 if bottom else 50) line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,{v_margin:d},,' - for i in range(syl_index, syl_index + syls_per_line): - syl_dur = round((syl_timings[i+1][0] - syl_timings[i][0]) * 100) - line += f'{{\k{syl_dur:d}}}{syl_timings[i][1]:s}' + for i in range(len(syl_line) - 1): + syl_dur = round((syl_line[i+1][0] - syl_line[i][0]) * 100) + line += f'{{\k{syl_dur:d}}}{syl_line[i][1]:s}' line += '\n' self.file.write(line) - syl_index += syls_per_line - bottom = not bottom - - start_time = timeToDate(syl_timings[syl_index][0]) - end_time = timeToDate(syl_timings[-1][0] + last_syl_dur//100) - v_margin = (150 if bottom else 50) - line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,{v_margin:d},,' - for i in range(syl_index, len(syl_timings) - 1): - syl_dur = round((syl_timings[i+1][0] - syl_timings[i][0]) * 100) - line += f'{{\k{syl_dur:d}}}{syl_timings[i][1]:s}' - line += f'{{\k{last_syl_dur:d}}}{syl_timings[-1][1]:s}\n' - - self.file.write(line) \ No newline at end of file + bottom = not bottom \ No newline at end of file diff --git a/autokara.py b/autokara.py index e7f69b174336edf669cc884c52275ae95fdea7f7..e1d294a1fb3f4fb273e1531b36f0e37bc22e3b75 100644 --- a/autokara.py +++ b/autokara.py @@ -4,7 +4,7 @@ import demucs.separate import subprocess import shlex from pathlib import Path -from assUtils import AssWriter +from assUtils import AssWriter, getSyls from cnn_madmom.segment import segment @@ -13,6 +13,7 @@ parser = argparse.ArgumentParser(description='AutoKara - Automatic karaoke timin parser.add_argument("source_file", type=str, help="The video/audio file to time") parser.add_argument("ass_file", type=str, help="The ASS file in which to output the karaoke") parser.add_argument("--vocals", action="store_true", help="Treat the input as vocals file, i.e. do not perform vocals extraction") +parser.add_argument("--ref", help="Use an ASS file as reference") args = parser.parse_args() @@ -39,15 +40,19 @@ if not args.vocals : else: vocals_file = args.source_file +if args.ref: + reference_syls = getSyls(args.ref) +else: + reference_syls = None print("Identifying syl starts...") -onsets = segment(vocals_file) -syls = [[t, 'la'] for t in onsets] +syls = segment(vocals_file, reference_syls=reference_syls) +print(syls) print("Syls found, writing ASS file...") writer = AssWriter() writer.openAss(ass_file) writer.writeHeader() -writer.writeSyls(syls, syls_per_line=10) +writer.writeSyls(syls) writer.closeAss() diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py index 174424ff29dff56207a56d93d916ff2b241c00ba..5e781905f378a5375f23ef31d6938e1b1e549856 100644 --- a/cnn_madmom/segment.py +++ b/cnn_madmom/segment.py @@ -1,12 +1,13 @@ import madmom import numpy as np import sys +import re import matplotlib.pyplot as plt from scipy.ndimage.filters import maximum_filter import scipy.signal as sg -def segment(songfile): +def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500): delay = -4 backtrack = False @@ -33,12 +34,12 @@ def segment(songfile): activation_smoothed = madmom.audio.signal.smooth(activation_function, 20) cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20) - onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.5, smooth=0) - onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2]) + onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.1, smooth=0) + onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1]) invalid_onsets_idx = [] magnitude_window = [2,8] - magnitude_threshold = 0.8 + magnitude_threshold = 1.2 for i in range(len(onsets)): if np.max(magnitude[onsets[i]+magnitude_window[0]:onsets[i]+magnitude_window[1]]) < magnitude_threshold: invalid_onsets_idx.append(i) @@ -46,6 +47,19 @@ def segment(songfile): onsets = np.delete(onsets, invalid_onsets_idx) + if reference_syls: + filtered_onsets = [] + for line in reference_syls: + syl_number = len(line) - 1 + line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])] + line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x])) + if syl_number > len(line_onsets): + print("WARNING : failed to detect enough onsets in line") + filtered_onsets += line_onsets[0:syl_number] + + onsets = np.array(sorted(filtered_onsets)) + + if backtrack: # Backtrack onsets to closest earlier local minimum backtrack_max_frames = 50 @@ -55,15 +69,72 @@ def segment(songfile): onsets[i] -= 1 onsets = (onsets + delay)/100 + #print(onsets) + + if reference_syls: + syls = [] + onset_index = 0 + for line in reference_syls: + #print(onset_index, " : ", line) + l = [[onsets[onset_index + i], line[i][1]] for i in range(len(line)-1)] + l.append([line[-1][0]/100, '']) + syls.append(l) + onset_index += (len(line) - 1) + else: + syls = [] + onset_index = 0 + for l in range(0, len(onsets), syls_per_line): + if onset_index + syls_per_line < len(onsets): + line = [[onset, 'la'] for onset in onsets[onset_index:onset_index+syls_per_line]] + else: + line = [[onset, 'la'] for onset in onsets[onset_index:]] + if onset_index + syls_per_line + 1 < len(onsets): + line.append([onsets[onset_index+syls_per_line+1], '']) + else: + line.append([onsets[-1] + last_syl_dur/100, '']) + syls.append(line) + onset_index += syls_per_line + + return syls + - print(onsets) - return onsets +if __name__ == "__main__": + + def dateToTime(date): + """ + The `date` should be in the following format: H:MM:SS.cs + """ + hourInMinuts = int(date[0:1]) * 60 + minutsInSeconds = (int(date[2:4]) + hourInMinuts) * 60 + secondsInCentiseconds = (int(date[5:7]) + minutsInSeconds) * 100 + return int(date[8:10]) + secondsInCentiseconds + + + def getSyls(ass_file): + SYLS = [] + with open(ass_file, 'r') as f: + CONTENT = f.read() + LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n"); + RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)") + for line in LINES_KARA.findall(CONTENT): + syl_line = [] + lastTime = dateToTime(line[0]) + for couple in RGX_TAGS.findall(line[2]): + syl_line.append((lastTime, couple[1], int(couple[0]))) + lastTime += int(couple[0]) + syl_line.append([lastTime, '', 0]) + SYLS.append(syl_line) + return SYLS + -if __name__ == "__main__": songfile = sys.argv[1] + if(len(sys.argv) == 3): + reference_syls = getSyls(sys.argv[2]) + + print(reference_syls) backtrack = False @@ -89,17 +160,26 @@ if __name__ == "__main__": activation_smoothed = madmom.audio.signal.smooth(activation_function, 20) cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20) - onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.5, smooth=0) + onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.25, smooth=0) onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2]) invalid_onsets_idx = [] magnitude_window = [2,8] - magnitude_threshold = 0.8 + magnitude_threshold = 1.2 for i in range(len(onsets)): if np.max(magnitude[onsets[i]+magnitude_window[0]:onsets[i]+magnitude_window[1]]) < magnitude_threshold: invalid_onsets_idx.append(i) onsets = np.delete(onsets, invalid_onsets_idx) + + filtered_onsets = [] + for line in reference_syls: + syl_number = len(line) - 1 + line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])] + line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x])) + filtered_onsets += line_onsets[0:syl_number] + + onsets = np.array(sorted(filtered_onsets)) # Backtrack onsets to closest earlier local minimum if backtrack: @@ -111,12 +191,17 @@ if __name__ == "__main__": print(onsets/100) + reference_onsets = [syl[0]+8 for line in reference_syls for syl in line[:-1]] + fig, axs = plt.subplots(nrows=2, sharex=True) axs[0].imshow(log_spec.T, origin='lower', aspect='auto') + axs[0].vlines(reference_onsets, 0, 140, colors='red') axs[1].plot(magnitude) #axs[1].plot(cnn_smoothed) #axs[1].plot(spectral_function, color='green') axs[1].plot(activation_smoothed, color='orange') - axs[1].vlines(onsets, 0, 1, colors='red') + axs[1].vlines(onsets, 0, 2, colors='red') + axs[1].hlines([max(magnitude_threshold, 0.5)], 0, onsets[-1]+100, colors='black') + plt.show()