From f036c17a606b7d1aaac438c6684feec2f4ce6ee5 Mon Sep 17 00:00:00 2001 From: Sting <loic.allegre@ensiie.fr> Date: Fri, 23 Jun 2023 10:52:15 +0200 Subject: [PATCH] Model training and syl prediction --- .gitignore | 1 + infer.py | 61 ++++++++++++++++++++++++++++++++++++ model.py | 40 +++++------------------- music_processor.py | 78 ---------------------------------------------- 4 files changed, 70 insertions(+), 110 deletions(-) create mode 100644 infer.py diff --git a/.gitignore b/.gitignore index 3679883..65e0cef 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ !music_processor.py !model.py !process_train_data.sh +!infer.py media/ \ No newline at end of file diff --git a/infer.py b/infer.py new file mode 100644 index 0000000..bc1b6f3 --- /dev/null +++ b/infer.py @@ -0,0 +1,61 @@ +from model import * +from music_processor import * +from assUtils import AssWriter +import pickle +import numpy as np +from scipy.signal import argrelmax +from librosa.util import peak_pick +from librosa.onset import onset_detect + + +def segment(songfile): + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + net = convNet() + net = net.to(device) + + with open('./data/pickles/train_data.pickle', mode='rb') as f: + songs = pickle.load(f) + + + if torch.cuda.is_available(): + net.load_state_dict(torch.load('./models/model.pth')) + else: + net.load_state_dict(torch.load('./models/model.pth', map_location='cpu')) + + song = songs[0] + inference = net.infer(song.feats, device, minibatch=4192) + inference = np.reshape(inference, (-1)) + + return detection(inference, song.samplerate) + + + +def detection(inference, samplerate): + + inference = smooth(inference, 5) + + + timestamp = (peak_pick(inference, pre_max=1, post_max=2, pre_avg=4, post_avg=5, delta=0.05, wait=3)) # 実際は7フレーム目のところの音 + + timestamp = timestamp*512/samplerate + + return timestamp + + + +if __name__ == '__main__': + + onsets = segment(sys.argv[1]) + syls = [[t, ''] for t in onsets] + + 
print(syls) + + writer = AssWriter() + writer.openAss("./media/test.ass") + writer.writeHeader() + writer.writeSyls(syls) + writer.closeAss() + + + diff --git a/model.py b/model.py index 3715c5c..31522f3 100644 --- a/model.py +++ b/model.py @@ -97,7 +97,7 @@ class convNet(nn.Module): yield (torch.from_numpy(np.array(x)).float()) - def train(self, songs, minibatch, epoch, device, soundlen=15, val_song=None, save_place='./models/model.pth', log='./log/log.txt', don_ka=0): + def train(self, songs, minibatch, epoch, device, soundlen=15, val_song=None, save_place='./models/model.pth', log='./log/log.txt'): """ Args: songs: the list of song @@ -113,39 +113,18 @@ class convNet(nn.Module): for song in songs: - timing = song.timestamp[:, 0] - sound = song.timestamp[:, 1] + timing = np.array([syl[0] for syl in song.timestamp]) + syllable = np.array([syl[1] for syl in song.timestamp]) song.answer = np.zeros((song.feats.shape[2])) - if don_ka == 0: - song.major_note_index = np.rint(timing[np.where(sound != 0)] * song.samplerate/512).astype(np.int32) - else: - song.major_note_index = np.rint(timing[np.where(sound == don_ka)] * song.samplerate/512).astype(np.int32) - song.minor_note_index = np.rint(timing[np.where(sound == 3-don_ka)] * song.samplerate/512).astype(np.int32) + + song.major_note_index = np.rint(timing[np.where(syllable != "")] * song.samplerate/512).astype(np.int32) song.major_note_index = np.delete(song.major_note_index, np.where(song.major_note_index >= song.feats.shape[2])) - song.minor_note_index = np.delete(song.minor_note_index, np.where(song.minor_note_index >= song.feats.shape[2])) - song.answer[song.major_note_index] = 1 - song.answer[song.minor_note_index] = 0.26 - song.answer = milden(song.answer) - - if val_song: - timing = val_song.timestamp[:, 0] - sound = val_song.timestamp[:, 1] - val_song.answer = np.zeros((val_song.feats.shape[2])) - - if don_ka == 0: - val_song.major_note_index = np.rint(timing[np.where(sound != 0)] * 
val_song.samplerate/512).astype(np.int32) - else: - val_song.major_note_index = np.rint(timing[np.where(sound == don_ka)] * val_song.samplerate/512).astype(np.int32) - val_song.minor_note_index = np.rint(timing[np.where(sound == 3-don_ka)] * val_song.samplerate/512).astype(np.int32) + song.answer[song.major_note_index] = 1 - val_song.major_note_index = np.delete(val_song.major_note_index, np.where(val_song.major_note_index >= val_song.feats.shape[2])) - val_song.minor_note_index = np.delete(val_song.minor_note_index, np.where(val_song.minor_note_index >= val_song.feats.shape[2])) - val_song.answer[val_song.major_note_index] = 1 - val_song.answer[val_song.minor_note_index] = 0.26 - val_song.answer = milden(val_song.answer) + song.answer = milden(song.answer) # training optimizer = optim.SGD(self.parameters(), lr=0.02) @@ -210,8 +189,5 @@ if __name__ == '__main__': soundlen = 15 epoch = 100 - if sys.argv[1] == 'don': - net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/don_model.pth', log='./data/log/don.txt', don_ka=1) - if sys.argv[1] == 'ka': - net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/ka_model.pth', log='./data/log/ka.txt', don_ka=2) \ No newline at end of file + net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/model.pth', log='./data/log/log.txt') \ No newline at end of file diff --git a/music_processor.py b/music_processor.py index 78890e8..ce9d857 100644 --- a/music_processor.py +++ b/music_processor.py @@ -57,84 +57,6 @@ class Audio: sf.write(filename, self.data[start_t:stop_t], self.samplerate) - def import_tja(self, filename, verbose=False, diff=False, difficulty=None): - """imports tja file and convert it into timestamp""" - - now = 0.0 - bpm = 100 - measure = [4, 4] # hyousi - self.timestamp = [] - skipflag = False - - 
with open(filename, "rb") as f: - while True: - line = f.readline() - try: - line = line.decode('sjis') - except UnicodeDecodeError: - line = line.decode('utf-8') - if line.find('//') != -1: - line = line[:line.find('//')] - if line[0:5] == "TITLE": - if verbose: - print("importing: ", line[6:]) - elif line[0:6] == "OFFSET": - now = -float(line[7:-2]) - elif line[0:4] == "BPM:": - bpm = float(line[4:-2]) - if line[0:6] == "COURSE": - if difficulty and difficulty > 0: - skipflag = True - difficulty -= 1 - elif line == "#START\r\n": - if skipflag: - skipflag = False - continue - break - - sound = [] - while True: - line = f.readline() - # print(line) - try: - line = line.decode('sjis') - except UnicodeDecodeError: - line = line.decode('utf-8') - - if line.find('//') != -1: - line = line[:line.find('//')] - if line[0] <= '9' and line[0] >= '0': - if line.find(',') != -1: - sound += line[0:line.find(',')] - beat = len(sound) - for i in range(beat): - if diff: - if int(sound[i]) in (1, 3, 5, 6, 7): - self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, 1]) - elif int(sound[i]) in (2, 4): - self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, 2]) - else: - if int(sound[i]) != 0: - self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, int(sound[i])]) - now += 60/bpm*measure[0] - sound = [] - else: - sound += line[0:-2] - elif line[0] == ',': - now += 60/bpm*measure[0] - elif line[0:10] == '#BPMCHANGE': - bpm = float(line[11:-2]) - elif line[0:8] == '#MEASURE': - measure[0] = int(line[line.find('/')-1]) - measure[1] = int(line[line.find('/')+1]) - elif line[0:6] == '#DELAY': - now += float(line[7:-2]) - elif line[0:4] == "#END": - if(verbose): - print("import complete!") - self.timestamp = np.array(self.timestamp) - break - def import_ass(self, filename): -- GitLab