diff --git a/.gitignore b/.gitignore
index d5283ab6a3cd51595be63e66e80e56a621715c5f..1504b226c7ba1d4903cb5fca38fc9edf28e7c04f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,6 @@
 !autokara.py
 !segment.py
 !assUtils.py
+!music_processor.py
+!model.py
 media/
\ No newline at end of file
diff --git a/assUtils.py b/assUtils.py
index 610081f7cf6cafd1d836c8e4e7c1d994d133ed8d..510e639f8c31ea0d41365f84233cf7abffabe374 100644
--- a/assUtils.py
+++ b/assUtils.py
@@ -11,6 +11,15 @@ def timeToDate(time):
 
     return f'{hours:02d}:{remainder_mins:02d}:{remainder_sec:.2f}'
 
 
+def dateToTime(date):
+    """
+    The `date` should be in the following format: H:MM:SS.cs
+    """
+    hourInMinuts = int(date[0:1]) * 60
+    minutsInSeconds = (int(date[2:4]) + hourInMinuts) * 60
+    secondsInCentiseconds = (int(date[5:7]) + minutsInSeconds) * 100
+    return int(date[8:10]) + secondsInCentiseconds
+
 
 class AssWriter:
 
@@ -44,12 +53,12 @@ Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
 
 
     def writeSyls(self, syl_timings):
         last_syl_dur = 500
-        start_time = timeToDate(syl_timings[0])
-        end_time = timeToDate(syl_timings[-1] + last_syl_dur//100)
+        start_time = timeToDate(syl_timings[0][0])
+        end_time = timeToDate(syl_timings[-1][0] + last_syl_dur//100)
         line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,0,,'
         for i in range(len(syl_timings) - 1):
-            syl_dur = round((syl_timings[i+1] - syl_timings[i]) * 100)
-            line += f'{{\k{syl_dur:d}}}'
-        line += f'{{\k{last_syl_dur:d}}}\n'
+            syl_dur = round((syl_timings[i+1][0] - syl_timings[i][0]) * 100)
+            line += f'{{\k{syl_dur:d}}}{syl_timings[i][1]:s}'
+        line += f'{{\k{last_syl_dur:d}}}{syl_timings[-1][1]:s}\n'
         self.file.write(line)
\ No newline at end of file
diff --git a/karaUtils.py b/karaUtils.py
deleted file mode 100755
index c47d07e1e3d5d1ce4fbd8a70f49e44a05487506c..0000000000000000000000000000000000000000
--- a/karaUtils.py
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/usr/bin/env python3
-import re
-import sys
-
-try:
-    FILE = sys.argv[1]
-except IndexError:
-    print("usage : %s inputFile.py" % sys.argv[0])
-
-with open(FILE, 'r') as f:
-    CONTENT = f.read()
-
-LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n");
-
-RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
-
-SYLS = []
-
-def dateToTime(date):
-    """
-    The `date` should be in the following format: HH:MM:SS.cs
-    """
-    hourInMinuts = int(date[0:1]) * 60
-    minutsInSeconds = (int(date[3:4]) + hourInMinuts) * 60
-    secondsInCentiseconds = (int(date[6:7]) + minutsInSeconds) * 100
-    return int(date[9:10]) + secondsInCentiseconds
-
-for line in LINES_KARA.findall(CONTENT):
-    lastTime = dateToTime(line[0])
-    for couple in RGX_TAGS.findall(line[2]):
-        SYLS.append((lastTime, couple[1], couple[0]))
-        lastTime += int(couple[0])
-
-print(SYLS)
diff --git a/model.py b/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3715c5cc1dbae6821e1351362ac55736dbb64e95
--- /dev/null
+++ b/model.py
@@ -0,0 +1,217 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import numpy as np
+from tqdm import tqdm
+from music_processor import *
+
+"""
+From the paper:
+starting from a stack of three spectrogram excerpts,
+convolution and max-pooling in turn compute a set of 20 feature maps
+classified with a fully-connected network.
+"""
+
+class convNet(nn.Module):
+    """
+    Copies the neural net used in the paper
+    "Improved musical onset detection with Convolutional Neural Networks".
+    src: https://ieeexplore.ieee.org/document/6854953
+    """
+
+    def __init__(self):
+
+        super(convNet, self).__init__()
+        # model
+        self.conv1 = nn.Conv2d(3, 10, (3, 7))
+        self.conv2 = nn.Conv2d(10, 20, 3)
+        self.fc1 = nn.Linear(1120, 256)
+        self.fc2 = nn.Linear(256, 120)
+        self.fc3 = nn.Linear(120, 1)
+
+
+    def forward(self, x, istraining=False, minibatch=1):
+
+        x = F.max_pool2d(F.relu(self.conv1(x)), (3, 1))
+        x = F.max_pool2d(F.relu(self.conv2(x)), (3, 1))
+        x = F.dropout(x.view(minibatch, -1), training=istraining)
+        x = F.dropout(F.relu(self.fc1(x)), training=istraining)
+        x = F.dropout(F.relu(self.fc2(x)), training=istraining)
+
+        return F.sigmoid(self.fc3(x))
+
+
+    def train_data_builder(self, feats, answer, major_note_index, samplerate, soundlen=15, minibatch=1, split=0.2):
+        """
+        Args:
+            feats: song.feats; Audio module
+            answer: song.answer; Audio module
+            major_note_index: answer labels; corresponding to feats
+            samplerate: song.samplerate; Audio module
+            soundlen: =15. horizontal length (in frames) of the image data passed to the model; here (80 * 15) excerpts are used
+            minibatch: training minibatch size
+            split: =0.2. fraction of the onset indices used for training
+        Variables:
+            minspace: minimum space between major note indices
+            maxspace: maximum space between major note indices
+            idx: indices into major_note_index or feats
+            dist: distance between two notes
+        """
+
+        # acceptable interval in seconds
+        minspace = 0.1
+        maxspace = 0.7
+
+        idx = np.random.permutation(major_note_index.shape[0] - soundlen) + soundlen // 2
+        X, y = [], []
+        cnt = 0
+
+        for i in range(int(idx.shape[0] * split)):
+
+            dist = major_note_index[idx[i] + 1] - major_note_index[idx[i]]  # distinguish by this value
+
+            if dist < maxspace * samplerate / 512 and dist > minspace * samplerate / 512:
+                for j in range(-1, dist + 2):
+                    X.append(feats[:, :, major_note_index[idx[i]] - soundlen // 2 + j : major_note_index[idx[i]] + soundlen // 2 + j + 1])
+                    y.append(answer[major_note_index[idx[i]] + j])
+                    cnt += 1
+
+                    if cnt % minibatch == 0:
+                        yield (torch.from_numpy(np.array(X)).float(), torch.from_numpy(np.array(y)).float())
+                        X, y = [], []
+
+
+    def infer_data_builder(self, feats, soundlen=15, minibatch=1):
+
+        x = []
+
+        for i in range(feats.shape[2] - soundlen):
+            x.append(feats[:, :, i:i+soundlen])
+
+            if (i + 1) % minibatch == 0:
+                yield (torch.from_numpy(np.array(x)).float())
+                x = []
+
+        if len(x) != 0:
+            yield (torch.from_numpy(np.array(x)).float())
+
+
+    def train(self, songs, minibatch, epoch, device, soundlen=15, val_song=None, save_place='./models/model.pth', log='./log/log.txt', don_ka=0):
+        """
+        Args:
+            songs: the list of songs
+            minibatch: minibatch size
+            epoch: number of training epochs
+            device: cpu / gpu
+            soundlen: width of one training image
+            val_song: validation song; to validate while training, pass a validation song here.
+            save_place: path to save the trained model
+            log: path of the log file
+            don_ka: don(1) or ka(2) or both(0); usually train don first, then ka.
+ """ + + for song in songs: + + timing = song.timestamp[:, 0] + sound = song.timestamp[:, 1] + song.answer = np.zeros((song.feats.shape[2])) + + if don_ka == 0: + song.major_note_index = np.rint(timing[np.where(sound != 0)] * song.samplerate/512).astype(np.int32) + else: + song.major_note_index = np.rint(timing[np.where(sound == don_ka)] * song.samplerate/512).astype(np.int32) + song.minor_note_index = np.rint(timing[np.where(sound == 3-don_ka)] * song.samplerate/512).astype(np.int32) + + song.major_note_index = np.delete(song.major_note_index, np.where(song.major_note_index >= song.feats.shape[2])) + song.minor_note_index = np.delete(song.minor_note_index, np.where(song.minor_note_index >= song.feats.shape[2])) + song.answer[song.major_note_index] = 1 + song.answer[song.minor_note_index] = 0.26 + song.answer = milden(song.answer) + + if val_song: + + timing = val_song.timestamp[:, 0] + sound = val_song.timestamp[:, 1] + val_song.answer = np.zeros((val_song.feats.shape[2])) + + if don_ka == 0: + val_song.major_note_index = np.rint(timing[np.where(sound != 0)] * val_song.samplerate/512).astype(np.int32) + else: + val_song.major_note_index = np.rint(timing[np.where(sound == don_ka)] * val_song.samplerate/512).astype(np.int32) + val_song.minor_note_index = np.rint(timing[np.where(sound == 3-don_ka)] * val_song.samplerate/512).astype(np.int32) + + val_song.major_note_index = np.delete(val_song.major_note_index, np.where(val_song.major_note_index >= val_song.feats.shape[2])) + val_song.minor_note_index = np.delete(val_song.minor_note_index, np.where(val_song.minor_note_index >= val_song.feats.shape[2])) + val_song.answer[val_song.major_note_index] = 1 + val_song.answer[val_song.minor_note_index] = 0.26 + val_song.answer = milden(val_song.answer) + + # training + optimizer = optim.SGD(self.parameters(), lr=0.02) + criterion = nn.MSELoss() + running_loss = 0 + val_loss = 0 + + for i in range(epoch): + for song in songs: + for X, y in self.train_data_builder(song.feats, song.answer, song.major_note_index, song.samplerate, soundlen, minibatch, split=0.2): + optimizer.zero_grad() + output = self(X.to(device), istraining=True, minibatch=minibatch) + target = y.to(device) + loss = criterion(output.squeeze(), target) + loss.backward() + optimizer.step() + running_loss += loss.data.item() + + with open(log, 'a') as f: + print("epoch: %.d running_loss: %.10f " % (i+1, running_loss), file=f) + + print("epoch: %.d running_loss: %.10f" % (i+1, running_loss)) + + running_loss = 0 + + if val_song: + inference = torch.from_numpy(self.infer(val_song.feats, device, minibatch=512)).to(device) + target = torch.from_numpy(val_song.answer[:-soundlen]).float().to(device) + loss = criterion(inference.squeeze(), target) + val_loss = loss.data.item() + + with open(log, 'a') as f: + print("val_loss: %.10f " % (val_loss), file=f) + + torch.save(self.state_dict(), save_place) + + + def infer(self, feats, device, minibatch=1): + + with torch.no_grad(): + inference = None + for x in tqdm(self.infer_data_builder(feats, minibatch=minibatch), total=feats.shape[2]//minibatch): + output = self(x.to(device), minibatch=x.shape[0]) + if inference is not None: + inference = np.concatenate((inference, output.cpu().numpy().reshape(-1))) + else: + inference = output.cpu().numpy().reshape(-1) + + return np.array(inference).reshape(-1) + + +if __name__ == '__main__': + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + net = convNet() + net = net.to(device) + + with open('./data/pickles/train_data.pickle', 
mode='rb') as f: + songs = pickle.load(f) + + minibatch = 128 + soundlen = 15 + epoch = 100 + + if sys.argv[1] == 'don': + net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/don_model.pth', log='./data/log/don.txt', don_ka=1) + + if sys.argv[1] == 'ka': + net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/ka_model.pth', log='./data/log/ka.txt', don_ka=2) \ No newline at end of file diff --git a/music_processor.py b/music_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba18d96187ee07403705e4ce008368113241b4a --- /dev/null +++ b/music_processor.py @@ -0,0 +1,345 @@ +import soundfile as sf +import matplotlib.pyplot as plt +import numpy as np +import os +from glob import glob +from scipy import signal +from scipy.fftpack import fft +from librosa.filters import mel +from librosa.display import specshow +from librosa import stft +from librosa.effects import pitch_shift +import pickle +import sys +from numba import jit, prange +from sklearn.preprocessing import normalize +import re +from assUtils import dateToTime, timeToDate + + +class Audio: + """ + audio class which holds music data and timestamp for notes. + Args: + filename: file name. + stereo: True or False; wether you have Don/Ka streo file or not. normaly True. + Variables: + Example: + >>>from music_processor import * + >>>song = Audio(filename) + >>># to get audio data + >>>song.data + >>># to import .tja files: + >>>song.import_tja(filename) + >>># to get data converted + >>>song.data = (song.data[:,0]+song.data[:,1])/2 + >>>fft_and_melscale(song, include_zero_cross=False) + """ + + def __init__(self, filename, stereo=True): + + self.data, self.samplerate = sf.read(filename, always_2d=True) + if stereo is False: + self.data = (self.data[:, 0]+self.data[:, 1])/2 + self.timestamp = [] + + + def plotaudio(self, start_t, stop_t): + + plt.plot(np.linspace(start_t, stop_t, stop_t-start_t), self.data[start_t:stop_t, 0]) + plt.show() + + + def save(self, filename="./savedmusic.wav", start_t=0, stop_t=None): + + if stop_t is None: + stop_t = self.data.shape[0] + sf.write(filename, self.data[start_t:stop_t], self.samplerate) + + + def import_tja(self, filename, verbose=False, diff=False, difficulty=None): + """imports tja file and convert it into timestamp""" + + now = 0.0 + bpm = 100 + measure = [4, 4] # hyousi + self.timestamp = [] + skipflag = False + + with open(filename, "rb") as f: + while True: + line = f.readline() + try: + line = line.decode('sjis') + except UnicodeDecodeError: + line = line.decode('utf-8') + if line.find('//') != -1: + line = line[:line.find('//')] + if line[0:5] == "TITLE": + if verbose: + print("importing: ", line[6:]) + elif line[0:6] == "OFFSET": + now = -float(line[7:-2]) + elif line[0:4] == "BPM:": + bpm = float(line[4:-2]) + if line[0:6] == "COURSE": + if difficulty and difficulty > 0: + skipflag = True + difficulty -= 1 + elif line == "#START\r\n": + if skipflag: + skipflag = False + continue + break + + sound = [] + while True: + line = f.readline() + # print(line) + try: + line = line.decode('sjis') + except UnicodeDecodeError: + line = line.decode('utf-8') + + if line.find('//') != -1: + line = line[:line.find('//')] + if line[0] <= '9' and line[0] >= '0': + if line.find(',') != -1: + sound += line[0:line.find(',')] + beat = len(sound) + for i in range(beat): + if diff: + if int(sound[i]) in (1, 3, 5, 6, 7): + 
+                                    self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, 1])
+                                elif int(sound[i]) in (2, 4):
+                                    self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, 2])
+                            else:
+                                if int(sound[i]) != 0:
+                                    self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, int(sound[i])])
+                        now += 60/bpm*measure[0]
+                        sound = []
+                    else:
+                        sound += line[0:-2]
+                elif line[0] == ',':
+                    now += 60/bpm*measure[0]
+                elif line[0:10] == '#BPMCHANGE':
+                    bpm = float(line[11:-2])
+                elif line[0:8] == '#MEASURE':
+                    measure[0] = int(line[line.find('/')-1])
+                    measure[1] = int(line[line.find('/')+1])
+                elif line[0:6] == '#DELAY':
+                    now += float(line[7:-2])
+                elif line[0:4] == "#END":
+                    if(verbose):
+                        print("import complete!")
+                    self.timestamp = np.array(self.timestamp)
+                    break
+
+
+    def import_ass(self, filename):
+
+        with open(filename, 'r') as f:
+            CONTENT = f.read()
+
+        LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n")
+
+        RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
+
+        SYLS = []
+
+        for line in LINES_KARA.findall(CONTENT):
+            lastTime = dateToTime(line[0])
+            for couple in RGX_TAGS.findall(line[2]):
+                self.timestamp.append((lastTime/100, couple[1]))
+                lastTime += int(couple[0])
+        print(type(lastTime))
+        print(self.timestamp)
+        self.timestamp = np.array(self.timestamp, dtype='float, object')
+
+
+def make_frame(data, nhop, nfft):
+    """
+    helping function for fft_and_melscale.
+    To use short time slices as training data, returns the data cut into nfft-sized frames, sliding by nhop (512) samples at a time.
+    """
+
+    length = data.shape[0]
+    framedata = np.concatenate((data, np.zeros(nfft)))  # zero padding
+    return np.array([framedata[i*nhop:i*nhop+nfft] for i in range(length//nhop)])
+
+
+@jit
+def fft_and_melscale(song, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False):
+    """
+    fft and melscale method.
+    fft: nfft = [1024, 2048, 4096]; extracts np.arrays from the data while varying the window length and applies the fast Fourier transform.
+    melscale: reduces the frequency dimension and takes log10 values.
+ """ + + feat_channels = [] + + for nfft in nffts: + + feats = [] + window = signal.blackmanharris(nfft) + filt = mel(song.samplerate, nfft, mel_nband, mel_freqlo, mel_freqhi) + + # get normal frame + frame = make_frame(song.data, nhop, nfft) + # print(frame.shape) + + # melscaling + processedframe = fft(window*frame)[:, :nfft//2+1] + processedframe = np.dot(filt, np.transpose(np.abs(processedframe)**2)) + processedframe = 20*np.log10(processedframe+0.1) + # print(processedframe.shape) + + feat_channels.append(processedframe) + + if include_zero_cross: + song.zero_crossing = np.where(np.diff(np.sign(song.data)))[0] + print(song.zero_crossing) + + return np.array(feat_channels) + + +@jit(parallel=True) +def multi_fft_and_melscale(songs, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False): + + for i in prange(len(songs)): + songs[i].feats = fft_and_melscale(songs[i], nhop, nffts, mel_nband, mel_freqlo, mel_freqhi) + + +def milden(data): + """put smaller value(0.25) to plus minus 1 frame.""" + + for i in range(data.shape[0]): + + if data[i] == 1: + if i > 0: + data[i-1] = 0.25 + if i < data.shape[0] - 1: + data[i+1] = 0.25 + + if data[i] == 0.26: + if i > 0: + data[i-1] = 0.1 + if i < data.shape[0] - 1: + data[i+1] = 0.1 + + return data + + +def smooth(x, window_len=11, window='hanning'): + + if x.ndim != 1: + raise ValueError + + if x.size < window_len: + raise ValueError + + if window_len < 3: + return x + + if not window in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']: + raise ValueError + + s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]] + # print(len(s)) + if window == 'flat': # moving average + w = np.ones(window_len, 'd') + else: + w = eval('np.'+window+'(window_len)') + + y = np.convolve(w/w.sum(), s, mode='valid') + + return y + + + +def music_for_validation(serv, deletemusic=True, verbose=False, difficulty=1): + + song = Audio(glob(serv+"/*.ogg")[0], stereo=False) + song.import_tja(glob(serv+"/*.tja")[-1], difficulty=difficulty) + song.feats = fft_and_melscale(song, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False) + + if deletemusic: + song.data = None + with open('./data/pickles/val_data.pickle', mode='wb') as f: + pickle.dump(song, f) + + +def music_for_train(serv, deletemusic=True, verbose=False, difficulty=0, diff=False, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False): + + songplaces = glob(serv) + songs = [] + + for songplace in songplaces: + + if verbose: + print(songplace) + + song = Audio(glob(songplace+"/*.ogg")[0]) + song.import_tja(glob(songplace+"/*.tja")[-1], difficulty=difficulty, diff=True) + song.data = (song.data[:, 0]+song.data[:, 1])/2 + songs.append(song) + + multi_fft_and_melscale(songs, nhop, nffts, mel_nband, mel_freqlo, mel_freqhi, include_zero_cross=include_zero_cross) + + if deletemusic: + for song in songs: + song.data = None + + with open('./data/pickles/train_data.pickle', mode='wb') as f: + pickle.dump(songs, f) + +def music_for_train_reduced(serv, deletemusic=True, verbose=False, difficulty=0, diff=False, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False): + + songplaces = glob(serv) + songs = [] + + for songplace in songplaces: + + if verbose: + print(songplace) + + song = Audio(glob(songplace+"/*.ogg")[0]) + song.import_tja(glob(songplace+"/*.tja")[-1], difficulty=difficulty, 
diff=True) + song.data = (song.data[:, 0]+song.data[:, 1])/2 + songs.append(song) + + multi_fft_and_melscale(songs, nhop, nffts, mel_nband, mel_freqlo, mel_freqhi, include_zero_cross=include_zero_cross) + + if deletemusic: + for song in songs: + song.data = None + + with open('./data/pickles/train_reduced.pickle', mode='wb') as f: + pickle.dump(songs, f) + + +def music_for_test(serv, deletemusic=True, verbose=False): + + song = Audio(glob(serv+"/*.ogg")[0], stereo=False) + # song.import_tja(glob(serv+"/*.tja")[-1]) + song.feats = fft_and_melscale(song, include_zero_cross=False) + with open('./data/pickles/test_data.pickle', mode='wb') as f: + pickle.dump(song, f) + + +if __name__ == "__main__": + + if sys.argv[1] == 'train': + print("preparing all train data processing...") + serv = "./data/train/*" + music_for_train(serv, verbose=True, difficulty=0, diff=True) + print("all train data processing done!") + + if sys.argv[1] == 'test': + print("test data proccesing...") + serv = "./data/test/" + music_for_test(serv) + print("test data processing done!") + +
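
A minimal usage sketch (not part of the patch) of how the new pieces fit together, assuming an already-initialised assUtils.AssWriter instance named `writer` and made-up syllable data; only dateToTime, timeToDate, and the new (time, syllable) format of writeSyls come from this diff:

from assUtils import dateToTime, timeToDate

# dateToTime parses an ASS timestamp 'H:MM:SS.cs' into centiseconds;
# timeToDate formats a time in seconds back into an ASS timestamp.
print(dateToTime('0:01:23.45'))   # 8345 centiseconds
print(timeToDate(83.45))          # expected '00:01:23.45'

# writeSyls now takes (time_in_seconds, syllable) pairs, as produced by
# Audio.import_ass; the values below are illustrative only.
syl_timings = [(12.00, 'ka'), (12.35, 'ra'), (12.70, 'o'), (13.10, 'ke')]
writer.writeSyls(syl_timings)     # writer: an AssWriter already opened on an output .ass file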