diff --git a/README.md b/README.md
index 04b43ea3eeac6eb6cdaf19d7e733738be4f5a9a6..6b712113d26cedb45fc24f1582b1a2168638b7e7 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,6 @@ An introduction to neural networks and deep learning:
 [Using CNNs on spectrogram images](https://www.ofai.at/~jan.schlueter/pubs/2014_icassp.pdf) (Schlüter, Böck, 2014) :
 - [MADMOM implementation](https://madmom.readthedocs.io/en/v0.16/modules/features/onsets.html)
-
 - Python implementation for Taiko rythm games : https://github.com/seiichiinoue/odcnn
 
 ### Other methods
@@ -70,27 +69,45 @@ Having a CUDA-capable GPU is optional, but can greatly reduce processing time in
 
 # Use
 
-## Inference
+## AutoKara
 
-To execute AutoKara on a MKV video file :
+To execute AutoKara from scratch on a MKV video file :
 ```bash
 $ python autokara.py video.mkv output.ass
 ```
 
+To execute AutoKara with existing syllable splits and line timings :
+```bash
+$ python autokara.py video.mkv output.ass --ref reference.ass
+```
+
 To execute AutoKara on a (pre-extracted) WAV vocals file :
 ```bash
 $ python autokara.py vocals.wav output.ass --vocals
 ```
 
+## Useful scripts
 
 To only extract .wav audio from a MKV file :
 ```bash
 $ ./extractWav.sh source_video output_audio
 ```
 
+To only extract the .ass subtitle file from a MKV file :
+```bash
+$ ./extractAss.sh source_video output_subs
+```
+
 To only separate vocals from instruments in an audio file :
 ```bash
 demucs --two-stems=vocals -o output_folder audio_file.wav
 ```
 
+Batch preprocessing (vocals + ASS extraction) of all videos in a directory :
+```bash
+$ ./preprocess_media.sh video_folder output_folder
+```
+
+
+
diff --git a/assUtils.py b/assUtils.py
index 3345d5b1665b0c9b21812ab3fcfbf6a9ee3a82c8..7b5fd284d7923cd7537cf24aad9062dc6f34df30 100644
--- a/assUtils.py
+++ b/assUtils.py
@@ -1,5 +1,6 @@
 import numpy as np
 import math
+import re
 
 
 def timeToDate(time):
@@ -21,6 +22,27 @@ def dateToTime(date):
     return int(date[8:10]) + secondsInCentiseconds
 
+
+def getSyls(ass_file):
+    SYLS = []
+    with open(ass_file, 'r') as f:
+        CONTENT = f.read()
+        LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}\.\d{2}),(\d+:\d{2}:\d{2}\.\d{2}),.*,karaoke,(.*)\n")
+        RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
+        for line in LINES_KARA.findall(CONTENT):
+            syl_line = []
+            lastTime = dateToTime(line[0])
+            for couple in RGX_TAGS.findall(line[2]):
+                syl_line.append([lastTime, couple[1], int(couple[0])])
+                lastTime += int(couple[0])
+            syl_line.append([lastTime, '', 0])
+            SYLS.append(syl_line)
+    return SYLS
+
+
+
+
 
 class AssWriter:
 
     def __init__(self):
@@ -39,38 +61,45 @@ ScriptType: v4.00+
 WrapStyle: 0
 ScaledBorderAndShadow: yes
 YCbCr Matrix: None
+PlayResX: 1920
+PlayResY: 1080
 
 [Aegisub Project Garbage]
 
 [V4+ Styles]
 Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
-Style: Default,Arial,48,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,2,2,2,10,10,10,1
+Style: Default,Arial,48,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,2,2,8,10,10,10,1
 
 [Events]
 Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
+Comment: 0,0:00:00.00,0:00:00.00,Default,,0,0,0,code once all,line_in_retime = -900; line_out_retime = 300; fade_in_dur = 200; fade_out_dur = 200
+Comment: 0,0:00:00.00,0:00:01.68,Default,,0,0,0,code once all,syl_shift = -75; syl_in_speed = 3.5; syl_out_speed = 3
+Comment: 
0,0:00:03.68,0:00:05.68,Default,,0,0,0,code syl all,bord = line.styleref.outline; shadow = line.styleref.shadow; blur = 2
+Comment: 0,0:00:05.68,0:00:05.68,Default,,0,0,0,,
+Comment: 0,0:00:05.68,0:00:05.68,Default,,0,0,0,,
+Comment: 10,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("start2syl",line_in_retime,syl_shift)!{\\fad(!fade_in_dur!,0)\\bord!bord!\shad!shadow!\\blur!blur!\\3c!line.styleref.color3!\c!line.styleref.color3!\pos($x,$y)}
+Comment: 20,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("start2syl",line_in_retime,syl_shift)!{\\fad(200,0)\\bord0\shad0\c!line.styleref.color2!\pos($x,$y)}
+Comment: 10,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("syl",syl_shift,syl_shift-$dur/syl_in_speed)!{\\3c!line.styleref.color3!\c!line.styleref.color3!\\bord!bord!\shad!shadow!\\blur!blur!\move($x,$y,$x,!$y-($height/8)!,0,!($dur > syl_in_speed * 400 and 400 or $dur/syl_in_speed)!)}
+Comment: 40,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("syl",syl_shift,syl_shift-$dur/syl_in_speed)!{\\bord0\shad0\c!line.styleref.color2!\\t(0,!($dur > syl_in_speed * 400 and 400 or $dur/syl_in_speed)!,\c!line.styleref.color1!)\move($x,$y,$x,!$y-($height/8)!,0,!($dur > syl_in_speed * 400 and 400 or $dur/syl_in_speed)!)}
+Comment: 10,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("postsyl",syl_shift-$dur/syl_in_speed,($lstart + $send + ($dur/12)) < ($lend + 300) and $dur/12 or $lend - $lstart - $send + 300)!{\\3c!line.styleref.color3!\c!line.styleref.color3!\\bord!bord!\shad!shadow!\\blur!blur!\move($x,!$y-($height/8)!,$x,!$y!,0,!$dur/syl_out_speed!)!($lstart + $send + $dur/12) < ($lend + 300) and "" or "\\fad(0,150)"!}
+Comment: 40,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("postsyl",syl_shift-$dur/syl_in_speed,($lstart + $send + ($dur/12)) < ($lend + 300) and $dur/12 or $lend - $lstart - $send + 300)!{\\bord0\shad0\move($x,!$y-($height/8)!,$x,!$y!,0,!$dur/syl_out_speed!)!($lstart + $send + $dur/12) < ($lend + 300) and "" or "\\fad(0,150)"!}
+Comment: 10,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("syl2end",$dur/12,line_out_retime)!{\\fad(0,!fade_out_dur!)\\bord!bord!\shad!shadow!\\blur!blur!\\3c!line.styleref.color3!\c!line.styleref.color3!\pos($x,$y)}
+Comment: 20,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("syl2end",$dur/12,line_out_retime)!{\c!line.styleref.color1!\\fad(0,!fade_out_dur!)\\bord0\shad0\pos($x,$y)}
+Comment: 0,0:00:05.68,0:00:05.68,Default,,0,0,0,,
+Comment: 0,0:00:05.68,0:00:05.68,Default,,0,0,0,,
 '''
         self.file.write(header)
 
-    def writeSyls(self, syl_timings, syls_per_line=10000):
-        last_syl_dur = 500
-        syl_index = 0
-        while syl_index < (len(syl_timings) - syls_per_line):
-            start_time = timeToDate(syl_timings[syl_index][0])
-            end_time = timeToDate(syl_timings[syl_index + syls_per_line][0])
-            line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,0,,'
-            for i in range(syl_index, syl_index + syls_per_line):
-                syl_dur = round((syl_timings[i+1][0] - syl_timings[i][0]) * 100)
-                line += f'{{\k{syl_dur:d}}}{syl_timings[i][1]:s}'
+    def writeSyls(self, syl_timings):
+        bottom = False
+        for syl_line in syl_timings:
+            start_time = timeToDate(syl_line[0][0])
+            end_time = timeToDate(syl_line[-1][0])
+            v_margin = (150 if bottom else 50)
+            line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,{v_margin:d},,'
+            for i in range(len(syl_line) - 1):
+                syl_dur = round((syl_line[i+1][0] - syl_line[i][0]) * 
100) + line += f'{{\k{syl_dur:d}}}{syl_line[i][1]:s}' line += '\n' self.file.write(line) - syl_index += syls_per_line - - start_time = timeToDate(syl_timings[syl_index][0]) - end_time = timeToDate(syl_timings[-1][0] + last_syl_dur//100) - line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,0,,' - for i in range(syl_index, len(syl_timings) - 1): - syl_dur = round((syl_timings[i+1][0] - syl_timings[i][0]) * 100) - line += f'{{\k{syl_dur:d}}}{syl_timings[i][1]:s}' - line += f'{{\k{last_syl_dur:d}}}{syl_timings[-1][1]:s}\n' - - self.file.write(line) \ No newline at end of file + bottom = not bottom \ No newline at end of file diff --git a/autokara.py b/autokara.py index e7f69b174336edf669cc884c52275ae95fdea7f7..9cbf49de3b534b67cb99d13a2f7702ec4725a086 100644 --- a/autokara.py +++ b/autokara.py @@ -4,15 +4,16 @@ import demucs.separate import subprocess import shlex from pathlib import Path -from assUtils import AssWriter +from assUtils import AssWriter, getSyls -from cnn_madmom.segment import segment +from autosyl.segment import segment parser = argparse.ArgumentParser(description='AutoKara - Automatic karaoke timing tool') parser.add_argument("source_file", type=str, help="The video/audio file to time") parser.add_argument("ass_file", type=str, help="The ASS file in which to output the karaoke") parser.add_argument("--vocals", action="store_true", help="Treat the input as vocals file, i.e. do not perform vocals extraction") +parser.add_argument("--ref", help="Use an ASS file as reference") args = parser.parse_args() @@ -39,15 +40,19 @@ if not args.vocals : else: vocals_file = args.source_file +if args.ref: + reference_syls = getSyls(args.ref) +else: + reference_syls = None print("Identifying syl starts...") -onsets = segment(vocals_file) -syls = [[t, 'la'] for t in onsets] +syls = segment(vocals_file, reference_syls=reference_syls) +print(syls) print("Syls found, writing ASS file...") writer = AssWriter() writer.openAss(ass_file) writer.writeHeader() -writer.writeSyls(syls, syls_per_line=10) +writer.writeSyls(syls) writer.closeAss() diff --git a/autosyl/segment.py b/autosyl/segment.py new file mode 100644 index 0000000000000000000000000000000000000000..2f8ac129f6f758c8084d9b402a5c387edb3dea03 --- /dev/null +++ b/autosyl/segment.py @@ -0,0 +1,246 @@ +import madmom +import numpy as np +import sys +import re +import matplotlib.pyplot as plt +import scipy.signal as sg +import parselmouth + + +def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500): + + delay = -4 + backtrack = False + + cnn = madmom.features.onsets.CNNOnsetProcessor() + spectral = madmom.features.onsets.SpectralOnsetProcessor('modified_kullback_leibler') + + sig = madmom.audio.signal.Signal(songfile, num_channels=1) + parsel = parselmouth.Sound(sig) + + spec = madmom.audio.spectrogram.Spectrogram(sig) + filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24) + log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1) + + magnitude = np.max(log_spec[:,:100], axis=1) + + cnn_function = cnn(sig) + spectral_function = spectral(sig) + spectral_function = spectral_function/(spectral_function.max()) + + #activation_function = 0.5*cnn_function + 0.5*spectral_function + activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function) + #activation_function = np.where(spectral_function > 0.14, cnn_function, 0) + #onsets = proc(activation_function) + + activation_smoothed = 
madmom.audio.signal.smooth(activation_function, 20)
+    cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
+    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.1, smooth=0)
+    #onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
+
+    pitch = parsel.to_pitch()
+    pitch_values = pitch.selected_array['frequency']
+
+    pad_before = round(pitch.xs()[0]*100)
+    pad_after = len(magnitude) - len(pitch_values) - pad_before
+
+    pitch_values = np.pad(pitch_values, (pad_before, pad_after), 'constant', constant_values=(0,0))
+
+    mask_function = magnitude * pitch_values
+    mask_function = mask_function/np.max(mask_function)
+    mask_threshold = 0.15
+    mask_window = [1,6]
+    invalid_onsets_idx = []
+
+    for i in range(len(onsets)):
+        if np.max(mask_function[onsets[i]+mask_window[0]:onsets[i]+mask_window[1]]) < mask_threshold:
+            invalid_onsets_idx.append(i)
+
+    onsets = np.delete(onsets, invalid_onsets_idx)
+
+
+    if reference_syls:
+        filtered_onsets = []
+        line_index = 0
+        for line in reference_syls:
+            line_index += 1
+            syl_number = len(line) - 1
+            line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
+            line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
+            missing_syls = 0
+            if syl_number > len(line_onsets):
+                print("WARNING : failed to detect enough onsets in line %d (%d, %d)" % (line_index, line[0][0], line[-1][0]))
+                missing_syls = syl_number - len(line_onsets)
+            filtered_onsets += line_onsets[0:syl_number]
+            filtered_onsets += [line[-1][0] for i in range(missing_syls)] # If missing some syllables, pad with 0-length syls
+
+        onsets = np.array(sorted(filtered_onsets))
+
+
+    if backtrack:
+        # Backtrack onsets to closest earlier local minimum
+        backtrack_max_frames = 50
+        for i in range(len(onsets)):
+            initial_onset = onsets[i]
+            while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
+                onsets[i] -= 1
+
+    onsets = (onsets + delay)/100
+    #print(onsets)
+
+    if reference_syls:
+        syls = []
+        onset_index = 0
+        for line in reference_syls:
+            #print(onset_index, " : ", line)
+            l = [[onsets[onset_index + i], line[i][1]] for i in range(len(line)-1)]
+            l.append([line[-1][0]/100, ''])
+            syls.append(l)
+            onset_index += (len(line) - 1)
+    else:
+        syls = []
+        onset_index = 0
+        for l in range(0, len(onsets), syls_per_line):
+            if onset_index + syls_per_line < len(onsets):
+                line = [[onset, 'la'] for onset in onsets[onset_index:onset_index+syls_per_line]]
+            else:
+                line = [[onset, 'la'] for onset in onsets[onset_index:]]
+            if onset_index + syls_per_line < len(onsets):
+                line.append([onsets[onset_index+syls_per_line], ''])
+            else:
+                line.append([onsets[-1] + last_syl_dur/100, ''])
+            syls.append(line)
+            onset_index += syls_per_line
+
+    return syls
+
+
+
+if __name__ == "__main__":
+
+    def dateToTime(date):
+        """
+        The `date` should be in the following format: H:MM:SS.cs
+        """
+        hourInMinuts = int(date[0:1]) * 60
+        minutsInSeconds = (int(date[2:4]) + hourInMinuts) * 60
+        secondsInCentiseconds = (int(date[5:7]) + minutsInSeconds) * 100
+        return int(date[8:10]) + secondsInCentiseconds
+
+
+    def getSyls(ass_file):
+        SYLS = []
+        with open(ass_file, 'r') as f:
+            CONTENT = f.read()
+            LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}\.\d{2}),(\d+:\d{2}:\d{2}\.\d{2}),.*,karaoke,(.*)\n")
+            RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
+            for line in LINES_KARA.findall(CONTENT):
+                syl_line = []
+                lastTime = dateToTime(line[0])
+                for couple in RGX_TAGS.findall(line[2]):
+                    syl_line.append((lastTime, couple[1], int(couple[0])))
+                    lastTime += int(couple[0])
+                syl_line.append([lastTime, '', 0])
+                SYLS.append(syl_line)
+        return SYLS
+
+
+
+
+    songfile = sys.argv[1]
+    reference_syls = None
+    if len(sys.argv) == 3: reference_syls = getSyls(sys.argv[2])
+
+    #print(reference_syls)
+
+    backtrack = False
+
+    cnn = madmom.features.onsets.CNNOnsetProcessor()
+    spectral = madmom.features.onsets.SpectralOnsetProcessor('modified_kullback_leibler')
+
+    sig = madmom.audio.signal.Signal(songfile, num_channels=1)
+    parsel = parselmouth.Sound(sig)
+
+    spec = madmom.audio.spectrogram.Spectrogram(sig)
+    filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24)
+    log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1)
+
+    magnitude = np.max(log_spec[:,:100], axis=1)
+
+    cnn_function = cnn(sig)
+    spectral_function = spectral(sig)
+    spectral_function = spectral_function/(spectral_function.max())
+
+    #activation_function = 0.5*cnn_function + 0.5*spectral_function
+    activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function)
+    #activation_function = np.where(spectral_function > 0.14, cnn_function, 0)
+    #onsets = proc(activation_function)
+
+    activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
+    cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
+    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.1, smooth=0)
+    #onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
+
+    pitch = parsel.to_pitch()
+    pitch_values = pitch.selected_array['frequency']
+
+    pad_before = round(pitch.xs()[0]*100)
+    pad_after = len(magnitude) - len(pitch_values) - pad_before
+
+    pitch_values = np.pad(pitch_values, (pad_before, pad_after), 'constant', constant_values=(0,0))
+
+    mask_function = magnitude * pitch_values
+    mask_function = mask_function/np.max(mask_function)
+    mask_threshold = 0.15
+    mask_window = [1,6]
+    invalid_onsets_idx = []
+
+    for i in range(len(onsets)):
+        if np.max(mask_function[onsets[i]+mask_window[0]:onsets[i]+mask_window[1]]) < mask_threshold:
+            invalid_onsets_idx.append(i)
+
+    onsets = np.delete(onsets, invalid_onsets_idx)
+
+
+    if reference_syls:
+        filtered_onsets = []
+        line_index = 0
+        for line in reference_syls:
+            line_index += 1
+            syl_number = len(line) - 1
+            line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
+            line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
+            if syl_number > len(line_onsets):
+                print("WARNING : failed to detect enough onsets in line %d (%d, %d)" % (line_index, line[0][0], line[-1][0]))
+            filtered_onsets += line_onsets[0:syl_number]
+
+        onsets = np.array(sorted(filtered_onsets))
+
+    # Backtrack onsets to closest earlier local minimum
+    if backtrack:
+        backtrack_max_frames = 50
+        for i in range(len(onsets)):
+            initial_onset = onsets[i]
+            while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
+                onsets[i] -= 1
+
+    print(onsets/100)
+
+    reference_onsets = [syl[0]+8 for line in reference_syls for syl in line[:-1]] if reference_syls else []
+
+    fig, axs = plt.subplots(nrows=2, sharex=True)
+    axs[0].imshow(log_spec.T, origin='lower', aspect='auto')
+    axs[0].vlines(reference_onsets, 0, 140, colors='red')
+    axs[0].plot((pitch_values/np.max(pitch_values))*140, color='yellow')
+    axs[1].plot(mask_function)
+    #axs[1].plot(cnn_smoothed)
+    #axs[1].plot(spectral_function, color='green')
+    axs[1].plot(activation_smoothed, 
color='orange')
+    axs[1].vlines(onsets, 0, 2, colors='red')
+    axs[1].hlines([max(mask_threshold, 0)], 0, onsets[-1]+100, colors='black')
+
+    #bins = np.arange(0, 1, 0.02)
+    #hist, hist_axs = plt.subplots(nrows=1)
+    #hist_axs.hist(mask_function, bins=bins)
+
+    plt.show()
diff --git a/cnn/model.py b/cnn/model.py
deleted file mode 100644
index 8d3d99902e6cf74965136e0782d608ec6deb53e6..0000000000000000000000000000000000000000
--- a/cnn/model.py
+++ /dev/null
@@ -1,193 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import numpy as np
-from tqdm import tqdm
-from cnn.music_processor import *
-
-"""
-On the paper,
-Starting from a stack of three spectrogram excerpts,
-convolution and max-pooling in turns compute a set of 20 feature maps
-classified with a fully-connected network.
-"""
-
-class convNet(nn.Module):
-    """
-    copies the neural net used in a paper.
-    "Improved musical onset detection with Convolutional Neural Networks".
-    src: https://ieeexplore.ieee.org/document/6854953
-    """
-
-    def __init__(self):
-
-        super(convNet, self).__init__()
-        # model
-        self.conv1 = nn.Conv2d(3, 10, (3, 7))
-        self.conv2 = nn.Conv2d(10, 20, 3)
-        self.fc1 = nn.Linear(1120, 256)
-        self.fc2 = nn.Linear(256, 120)
-        self.fc3 = nn.Linear(120, 1)
-
-
-    def forward(self, x, istraining=False, minibatch=1):
-
-        x = F.max_pool2d(F.relu(self.conv1(x)), (3, 1))
-        x = F.max_pool2d(F.relu(self.conv2(x)), (3, 1))
-        x = F.dropout(x.view(minibatch, -1), training=istraining)
-        x = F.dropout(F.relu(self.fc1(x)), training=istraining)
-        x = F.dropout(F.relu(self.fc2(x)), training=istraining)
-
-        return F.sigmoid(self.fc3(x))
-
-
-    def train_data_builder(self, feats, answer, major_note_index, samplerate, soundlen=15, minibatch=1, split=0.2):
-        """
-        Args:
-            feats: song.feats; Audio module
-            answers: song.answers; Audio module
-            major_note_index: answer labels; corresponding to feats
-            samplerate: song.samplerate; Audio module
-            soundlen: =15. Width of the image data passed to the model; (80 * 15)-sized data is used here.
-            minibatch: training minibatch
-            split: =1. 
- Variables: - minspace: minimum space between major note indexs - maxspace: maximum space between major note indexs - idx: index of major_note_index or feats - dist: distance of two notes - """ - - # acceptable interval in seconds - minspace = 0.1 - maxspace = 0.7 - - idx = np.random.permutation(major_note_index.shape[0] - soundlen) + soundlen // 2 - X, y = [], [] - cnt = 0 - - for i in range(int(idx.shape[0] * split)): - - dist = major_note_index[idx[i] + 1] - major_note_index[idx[i]] # distinguish by this value - - if dist < maxspace * samplerate / 512 and dist > minspace * samplerate / 512: - for j in range(-1, dist + 2): - X.append(feats[:, :, major_note_index[idx[i]] - soundlen // 2 + j : major_note_index[idx[i]] + soundlen // 2 + j + 1]) - y.append(answer[major_note_index[idx[i]] + j]) - cnt += 1 - - if cnt % minibatch == 0: - yield (torch.from_numpy(np.array(X)).float(), torch.from_numpy(np.array(y)).float()) - X, y = [], [] - - - def infer_data_builder(self, feats, soundlen=15, minibatch=1): - - x = [] - - for i in range(feats.shape[2] - soundlen): - x.append(feats[:, :, i:i+soundlen]) - - if (i + 1) % minibatch == 0: - yield (torch.from_numpy(np.array(x)).float()) - x = [] - - if len(x) != 0: - yield (torch.from_numpy(np.array(x)).float()) - - - def train(self, songs, minibatch, epoch, device, soundlen=15, val_song=None, save_place='./models/model.pth', log='./log/log.txt'): - """ - Args: - songs: the list of song - minibatch: minibatch value - epoch: number of train - device: cpu / gpu - soundlen: width of one train data's image - val_song: validation song, if you wanna validation while training, give a path of validation song data. - save_place: save place path - log: log place path - don-ka: don(1) or ka(2) or both(0), usually, firstly, train don, then, train ka. 
- """ - - for song in songs: - - timing = np.array([syl[0] for syl in song.timestamp]) - syllable = np.array([syl[1] for syl in song.timestamp]) - song.answer = np.zeros((song.feats.shape[2])) - - - song.major_note_index = np.rint(timing[np.where(syllable != 0)] * song.samplerate/512).astype(np.int32) - - song.major_note_index = np.delete(song.major_note_index, np.where(song.major_note_index >= song.feats.shape[2])) - - song.answer[song.major_note_index] = 1 - - song.answer = milden(song.answer) - - # training - optimizer = optim.SGD(self.parameters(), lr=0.02) - criterion = nn.MSELoss() - running_loss = 0 - val_loss = 0 - - for i in range(epoch): - for song in songs: - for X, y in self.train_data_builder(song.feats, song.answer, song.major_note_index, song.samplerate, soundlen, minibatch, split=0.2): - optimizer.zero_grad() - output = self(X.to(device), istraining=True, minibatch=minibatch) - target = y.to(device) - loss = criterion(output.squeeze(), target) - loss.backward() - optimizer.step() - running_loss += loss.data.item() - - with open(log, 'a') as f: - print("epoch: %.d running_loss: %.10f " % (i+1, running_loss), file=f) - - print("epoch: %.d running_loss: %.10f" % (i+1, running_loss)) - - running_loss = 0 - - if val_song: - inference = torch.from_numpy(self.infer(val_song.feats, device, minibatch=512)).to(device) - target = torch.from_numpy(val_song.answer[:-soundlen]).float().to(device) - loss = criterion(inference.squeeze(), target) - val_loss = loss.data.item() - - with open(log, 'a') as f: - print("val_loss: %.10f " % (val_loss), file=f) - - torch.save(self.state_dict(), save_place) - - - def infer(self, feats, device, minibatch=1): - - with torch.no_grad(): - inference = None - for x in tqdm(self.infer_data_builder(feats, minibatch=minibatch), total=feats.shape[2]//minibatch): - output = self(x.to(device), minibatch=x.shape[0]) - if inference is not None: - inference = np.concatenate((inference, output.cpu().numpy().reshape(-1))) - else: - inference = output.cpu().numpy().reshape(-1) - - return np.array(inference).reshape(-1) - - -if __name__ == '__main__': - - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - net = convNet() - net = net.to(device) - - with open('./data/pickles/train_data.pickle', mode='rb') as f: - songs = pickle.load(f) - - minibatch = 128 - soundlen = 15 - epoch = 100 - - - net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/model.pth', log='./data/log/log.txt') \ No newline at end of file diff --git a/cnn/music_processor.py b/cnn/music_processor.py deleted file mode 100644 index c3e8bc3ec520217655ae44b8085dd1ffdaee7ed6..0000000000000000000000000000000000000000 --- a/cnn/music_processor.py +++ /dev/null @@ -1,230 +0,0 @@ -import soundfile as sf -import matplotlib.pyplot as plt -import numpy as np -import os -from glob import glob -from scipy import signal -from scipy.fftpack import fft -from librosa.filters import mel -from librosa.display import specshow -from librosa import stft -from librosa.effects import pitch_shift -import pickle -import sys -from numba import jit, prange -from sklearn.preprocessing import normalize -import re -from assUtils import dateToTime, timeToDate - - -class Audio: - """ - audio class which holds music data and timestamp for notes. - Args: - filename: file name. - stereo: True or False; wether you have Don/Ka streo file or not. normaly True. 
-    Variables:
-    Example:
-        >>>from music_processor import *
-        >>>song = Audio(filename)
-        >>># to get audio data
-        >>>song.data
-        >>># to import .tja files:
-        >>>song.import_tja(filename)
-        >>># to get data converted
-        >>>song.data = (song.data[:,0]+song.data[:,1])/2
-        >>>fft_and_melscale(song, include_zero_cross=False)
-    """
-
-    def __init__(self, filename, stereo=True):
-
-        self.data, self.samplerate = sf.read(filename, always_2d=True)
-        if stereo is False:
-            self.data = (self.data[:, 0]+self.data[:, 1])/2
-        self.timestamp = []
-
-
-    def plotaudio(self, start_t, stop_t):
-
-        plt.plot(np.linspace(start_t, stop_t, stop_t-start_t), self.data[start_t:stop_t, 0])
-        plt.show()
-
-
-    def save(self, filename="./savedmusic.wav", start_t=0, stop_t=None):
-
-        if stop_t is None:
-            stop_t = self.data.shape[0]
-        sf.write(filename, self.data[start_t:stop_t], self.samplerate)
-
-
-
-
-    def import_ass(self, filename):
-
-        with open(filename, 'r') as f:
-            CONTENT = f.read()
-
-        LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n");
-
-        RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
-
-        SYLS = []
-
-        for line in LINES_KARA.findall(CONTENT):
-            lastTime = dateToTime(line[0])
-            for couple in RGX_TAGS.findall(line[2]):
-                self.timestamp.append((lastTime/100, 1 if len(couple[1]) > 0 else 0))
-                lastTime += int(couple[0])
-        self.timestamp = np.array(self.timestamp, dtype='float, int')
-
-
-def make_frame(data, nhop, nfft):
-    """
-    helping function for fftandmelscale.
-    Returns the data as an array of nfft-sized frames, sliding by nhop (512) samples, so that short time slices can be used as training data.
-    """
-
-    length = data.shape[0]
-    framedata = np.concatenate((data, np.zeros(nfft)))  # zero padding
-    return np.array([framedata[i*nhop:i*nhop+nfft] for i in range(length//nhop)])
-
-
-#@jit
-def fft_and_melscale(song, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False):
-    """
-    fft and melscale method.
-    fft: nfft = [1024, 2048, 4096]; extracts np.arrays from the data while varying the window length, and applies the fast Fourier transform.
-    melscale: reduces the frequency dimension and takes the log10 of the values. 
- """ - - feat_channels = [] - - for nfft in nffts: - - feats = [] - window = signal.blackmanharris(nfft) - filt = mel(sr=song.samplerate, n_fft=nfft, n_mels=mel_nband, fmin=mel_freqlo, fmax=mel_freqhi) - - # get normal frame - frame = make_frame(song.data, nhop, nfft) - # print(frame.shape) - - # melscaling - processedframe = fft(window*frame)[:, :nfft//2+1] - processedframe = np.dot(filt, np.transpose(np.abs(processedframe)**2)) - processedframe = 20*np.log10(processedframe+0.1) - # print(processedframe.shape) - - feat_channels.append(processedframe) - - if include_zero_cross: - song.zero_crossing = np.where(np.diff(np.sign(song.data)))[0] - print(song.zero_crossing) - - return np.array(feat_channels) - - -#@jit(parallel=True) -def multi_fft_and_melscale(songs, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False): - - for i in prange(len(songs)): - songs[i].feats = fft_and_melscale(songs[i], nhop, nffts, mel_nband, mel_freqlo, mel_freqhi) - - -def milden(data): - """put smaller value(0.25) to plus minus 1 frame.""" - - for i in range(data.shape[0]): - - if data[i] == 1: - if i > 0: - data[i-1] = 0.25 - if i < data.shape[0] - 1: - data[i+1] = 0.25 - - if data[i] == 0.26: - if i > 0: - data[i-1] = 0.1 - if i < data.shape[0] - 1: - data[i+1] = 0.1 - - return data - - -def smooth(x, window_len=11, window='hanning'): - - if x.ndim != 1: - raise ValueError - - if x.size < window_len: - raise ValueError - - if window_len < 3: - return x - - if not window in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']: - raise ValueError - - s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]] - # print(len(s)) - if window == 'flat': # moving average - w = np.ones(window_len, 'd') - else: - w = eval('np.'+window+'(window_len)') - - y = np.convolve(w/w.sum(), s, mode='valid') - - return y - - - - -def music_for_train(serv, deletemusic=True, verbose=False, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False): - - songplaces = glob(serv) - songs = [] - - for songplace in songplaces: - - if verbose: - print(songplace) - - songname = songplace.split("/")[-1] - - song = Audio(glob(songplace+"/*.ogg")[0]) - song.import_ass(glob(songplace+"/*.ass")[-1]) - song.data = (song.data[:, 0]+song.data[:, 1])/2 - - song.feats = fft_and_melscale(song, nhop, nffts, mel_nband, mel_freqlo, mel_freqhi) - - if deletemusic: - song.data = None - - with open(f'./data/pickles/{songname:s}.pickle', mode='wb') as f: - pickle.dump(song, f) - - -def music_for_test(serv, deletemusic=True, verbose=False): - - song = Audio(glob(serv+"/*.ogg")[0], stereo=False) - # song.import_tja(glob(serv+"/*.tja")[-1]) - song.feats = fft_and_melscale(song, include_zero_cross=False) - with open('./data/pickles/test_data.pickle', mode='wb') as f: - pickle.dump(song, f) - - -if __name__ == "__main__": - - if sys.argv[1] == 'train': - print("preparing all train data processing...") - serv = f'./{sys.argv[2]:s}/*' - music_for_train(serv, verbose=True) - print("all train data processing done!") - - if sys.argv[1] == 'test': - print("test data proccesing...") - serv = f'./{sys.argv[2]:s}/*' - music_for_test(serv) - print("test data processing done!") - - diff --git a/cnn/segment.py b/cnn/segment.py deleted file mode 100644 index 13f293ce77eaa080500a8282b228b249720c31eb..0000000000000000000000000000000000000000 --- a/cnn/segment.py +++ /dev/null @@ -1,60 +0,0 @@ -from cnn.model import * -from cnn.music_processor import * -from 
assUtils import AssWriter
-import pickle
-import numpy as np
-from scipy.signal import argrelmax
-from librosa.util import peak_pick
-from librosa.onset import onset_detect
-
-
-def segment(songfile):
-
-
-    song = Audio(songfile, stereo=False)
-    song.feats = fft_and_melscale(song, include_zero_cross=False)
-
-    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-    net = convNet()
-    net = net.to(device)
-
-    if torch.cuda.is_available():
-        net.load_state_dict(torch.load('./models/model.pth'))
-    else:
-        net.load_state_dict(torch.load('./models/model.pth', map_location='cpu'))
-
-    inference = net.infer(song.feats, device, minibatch=4192)
-    inference = np.reshape(inference, (-1))
-
-    return detection(inference, song.samplerate)
-
-
-
-def detection(inference, samplerate):
-
-    inference = smooth(inference, 5)
-
-
-    timestamp = (peak_pick(inference, pre_max=1, post_max=2, pre_avg=4, post_avg=5, delta=0.05, wait=3)) # actually the sound at the 7th frame
-
-    timestamp = timestamp*512/samplerate
-
-    return timestamp
-
-
-
-if __name__ == '__main__':
-
-    onsets = segment(sys.argv[1])
-    syls = [[t, ''] for t in onsets]
-
-    print(syls)
-
-    writer = AssWriter()
-    writer.openAss("./media/test.ass")
-    writer.writeHeader()
-    writer.writeSyls(syls)
-    writer.closeAss()
-
-
-
diff --git a/cnn_madmom/segment.py b/cnn_madmom/segment.py
deleted file mode 100644
index a86921df036ac72943a271da2440cbfae1b1e0a4..0000000000000000000000000000000000000000
--- a/cnn_madmom/segment.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import madmom
-import numpy as np
-import sys
-import matplotlib.pyplot as plt
-from scipy.ndimage.filters import maximum_filter
-
-
-def segment(songfile):
-
-    delay = -4
-    backtrack = False
-
-    cnn = madmom.features.onsets.CNNOnsetProcessor()
-    spectral = madmom.features.onsets.SpectralOnsetProcessor('complex_domain')
-
-
-    spec = spec = madmom.audio.spectrogram.Spectrogram(songfile, num_channels=1)
-    filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24)
-    log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1)
-
-    cnn_function = cnn(songfile, num_channels=1)
-    spectral_function = spectral(songfile, num_channels=1)
-    spectral_function = spectral_function/(spectral_function.max())
-
-    #activation_function = 0.5*cnn_function + 0.5*spectral_function
-    activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function)
-    #activation_function = np.where(spectral_function > 0.14, cnn_function, 0)
-    #onsets = proc(activation_function)
-
-    activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
-    cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.6, smooth=0)
-    onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2])
-
-
-    if backtrack:
-        # Backtrack onsets to closest earlier local minimum
-        backtrack_max_frames = 50
-        for i in range(len(onsets)):
-            initial_onset = onsets[i]
-            while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
-                onsets[i] -= 1
-
-    onsets = (onsets + delay)/100
-
-    print(onsets)
-
-    return onsets
-
-
-
-if __name__ == "__main__":
-    songfile = sys.argv[1]
-
-    backtrack = False
-
-    cnn = madmom.features.onsets.CNNOnsetProcessor()
-    spectral = madmom.features.onsets.SpectralOnsetProcessor('modified_kullback_leibler')
-
-
-    spec = spec = madmom.audio.spectrogram.Spectrogram(songfile, 
num_channels=1) - filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24) - log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1) - - cnn_function = cnn(songfile, num_channels=1) - spectral_function = spectral(songfile, num_channels=1) - spectral_function = spectral_function/(spectral_function.max()) - - #activation_function = 0.5*cnn_function + 0.5*spectral_function - activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function) - #activation_function = np.where(spectral_function > 0.14, cnn_function, 0) - #onsets = proc(activation_function) - - activation_smoothed = madmom.audio.signal.smooth(activation_function, 20) - cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20) - onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=0.6, smooth=0) - onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.2]) - - # Backtrack onsets to closest earlier local minimum - if backtrack: - backtrack_max_frames = 50 - for i in range(len(onsets)): - initial_onset = onsets[i] - while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames): - onsets[i] -= 1 - - print(onsets/100) - - fig, axs = plt.subplots(nrows=2, sharex=True) - axs[0].imshow(log_spec.T, origin='lower', aspect='auto') - axs[1].plot(cnn_smoothed) - axs[1].plot(spectral_function, color='green') - axs[1].plot(activation_smoothed, color='orange') - axs[1].vlines(onsets, 0, 1, colors='red') - - plt.show() diff --git a/cnn_prepare_data.py b/cnn_prepare_data.py deleted file mode 100644 index 81b96fad6704937882dfb66815d6d95f2e837a8a..0000000000000000000000000000000000000000 --- a/cnn_prepare_data.py +++ /dev/null @@ -1,15 +0,0 @@ -from cnn.music_processor import * - - - -if sys.argv[1] == 'train': - print("preparing all train data processing...") - serv = f'./{sys.argv[2]:s}/*' - music_for_train(serv, verbose=True) - print("all train data processing done!") - -if sys.argv[1] == 'test': - print("test data proccesing...") - serv = f'./{sys.argv[2]:s}/*' - music_for_test(serv) - print("test data processing done!") \ No newline at end of file diff --git a/cnn_train.py b/cnn_train.py deleted file mode 100644 index b6325379525cdf6035c656e77501cca05c4de986..0000000000000000000000000000000000000000 --- a/cnn_train.py +++ /dev/null @@ -1,23 +0,0 @@ -from cnn.model import * -from cnn.music_processor import * -from glob import glob - - -device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') -net = convNet() -net = net.to(device) - -songplaces = glob('./data/pickles/*.pickle') -songs = [] - -for songplace in songplaces: - with open(songplace, mode='rb') as f: - song = pickle.load(f) - songs.append(song) - -minibatch = 128 -soundlen = 15 -epoch = 100 - - -net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/model.pth', log='./data/log/log.txt') \ No newline at end of file diff --git a/process_train_data.sh b/preprocess_media.sh similarity index 100% rename from process_train_data.sh rename to preprocess_media.sh diff --git a/requirements.txt b/requirements.txt index 770a61b27edbe79b68aa14593a431c0a841062bf..037cec811b5956b459fe862eae3b0b82dfc61d45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ tqdm scipy cython mido -git+https://github.com/CPJKU/madmom.git \ No newline at end of file +git+https://github.com/CPJKU/madmom.git 
+praat-parselmouth \ No newline at end of file diff --git a/rosa/segment.py b/rosa/segment.py deleted file mode 100644 index f455a22f2e50c28e51802752254cd8d6e217645a..0000000000000000000000000000000000000000 --- a/rosa/segment.py +++ /dev/null @@ -1,58 +0,0 @@ -import librosa -import numpy as np -# import matplotlib.pyplot as plt -import sys - - - -class Segment: - - def __init__(self, file): - self.file = file - - - def onsets(self): - ''' - Use librosa's onset detection to detect syllable start times - ''' - - y, sr = librosa.load(self.file) - - o_env = librosa.onset.onset_strength(y=y, sr=sr) - times = librosa.times_like(o_env, sr=sr) - onset_raw = librosa.onset.onset_detect(onset_envelope=o_env, sr=sr) - onset_bt = librosa.onset.onset_backtrack(onset_raw, o_env) - - S = np.abs(librosa.stft(y=y)) - rms = librosa.feature.rms(S=S) - onset_bt_rms = librosa.onset.onset_backtrack(onset_raw, rms[0]) - - onset_bt_times = librosa.frames_to_time(onset_bt, sr=sr) - onset_bt_rms_times = librosa.frames_to_time(onset_bt_rms, sr=sr) - - onset_raw_times = librosa.frames_to_time(onset_raw, sr=sr) - - # print(onset_bt_rms_times) - - ''' - fig, ax = plt.subplots(nrows=3, sharex=True) - librosa.display.specshow(librosa.amplitude_to_db(S, ref=np.max),y_axis='log', x_axis='time', ax=ax[0]) - ax[0].label_outer() - ax[1].plot(times, o_env, label='Onset strength') - ax[1].vlines(librosa.frames_to_time(onset_raw), 0, o_env.max(), label='Raw onsets') - ax[1].vlines(librosa.frames_to_time(onset_bt), 0, o_env.max(), label='Backtracked', color='r') - ax[1].legend() - ax[1].label_outer() - ax[2].plot(times, rms[0], label='RMS') - ax[2].vlines(librosa.frames_to_time(onset_bt_rms), 0, rms.max(), label='Backtracked (RMS)', color='r') - ax[2].legend() - - plt.show() - ''' - - return onset_raw_times - - -if __name__ == "__main__": - seg = Segment(sys.argv[1]) - seg.onsets() \ No newline at end of file
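
A note on the timing conventions this patch relies on, as a quick illustrative sketch (the timestamp below is made up; `dateToTime` and `getSyls` are the assUtils helpers added above):

```python
from assUtils import dateToTime, getSyls

# ASS timestamps use the H:MM:SS.cs format; dateToTime() converts them to
# integer centiseconds, which conveniently equals the 100 fps frame index
# used by the madmom activation functions in autosyl/segment.py.
assert dateToTime("0:01:10.50") == 7050   # 1 min + 10.50 s -> 7050 cs

# {\k} karaoke tags also store durations in centiseconds:
# "{\k25}ta{\k40}da" means "ta" lasts 0.25 s and "da" lasts 0.40 s.
# getSyls("reference.ass") therefore returns, for each karaoke comment line:
#   [[start_cs, syl_text, k_duration_cs], ..., [line_end_cs, '', 0]]
```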
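Likewise, a hypothetical driver for the new detector, mirroring what autokara.py does (the file names are placeholders; per line, segment() returns [start_seconds, syllable_text] pairs followed by a final [line_end, ''] marker):

```python
from assUtils import getSyls
from autosyl.segment import segment

# Optional: snap detected onsets to the syllable counts of an existing karaoke.
reference = getSyls("reference.ass")   # or None to let segment() split lines itself

lines = segment("vocals.wav", reference_syls=reference)
for line in lines:
    for start, text in line[:-1]:      # last element is the line-end marker
        print(f"{start:7.2f}s  {text}")
```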