diff --git a/cnn/model.py b/cnn/model.py
index c72be39035582081b5aa29f3a652dcc3629b2f0c..8d3d99902e6cf74965136e0782d608ec6deb53e6 100644
--- a/cnn/model.py
+++ b/cnn/model.py
@@ -118,7 +118,7 @@ class convNet(nn.Module):
 
         song.answer = np.zeros((song.feats.shape[2]))
 
-        song.major_note_index = np.rint(timing[np.where(syllable != "")] * song.samplerate/512).astype(np.int32)
+        song.major_note_index = np.rint(timing[np.where(syllable != 0)] * song.samplerate/512).astype(np.int32)
 
         song.major_note_index = np.delete(song.major_note_index, np.where(song.major_note_index >= song.feats.shape[2]))
 
diff --git a/cnn/music_processor.py b/cnn/music_processor.py
index ce9d8571a758b42e17813c641fd35f4660f32046..c3e8bc3ec520217655ae44b8085dd1ffdaee7ed6 100644
--- a/cnn/music_processor.py
+++ b/cnn/music_processor.py
@@ -73,9 +73,9 @@ class Audio:
         for line in LINES_KARA.findall(CONTENT):
             lastTime = dateToTime(line[0])
             for couple in RGX_TAGS.findall(line[2]):
-                self.timestamp.append((lastTime/100, couple[1]))
+                self.timestamp.append((lastTime/100, 1 if len(couple[1]) > 0 else 0))
                 lastTime += int(couple[0])
-        self.timestamp = np.array(self.timestamp, dtype='float, object')
+        self.timestamp = np.array(self.timestamp, dtype='float, int')
 
 
 def make_frame(data, nhop, nfft):
@@ -189,19 +189,19 @@ def music_for_train(serv, deletemusic=True, verbose=False, nhop=512, nffts=[1024
         if verbose:
             print(songplace)
+        songname = songplace.split("/")[-1]
+
         song = Audio(glob(songplace+"/*.ogg")[0])
         song.import_ass(glob(songplace+"/*.ass")[-1])
         song.data = (song.data[:, 0]+song.data[:, 1])/2
-        songs.append(song)
-
-    multi_fft_and_melscale(songs, nhop, nffts, mel_nband, mel_freqlo, mel_freqhi, include_zero_cross=include_zero_cross)
+        song.feats = fft_and_melscale(song, nhop, nffts, mel_nband, mel_freqlo, mel_freqhi)
 
-    if deletemusic:
-        for song in songs:
+        if deletemusic:
             song.data = None
-
-    with open('./data/pickles/train_data.pickle', mode='wb') as f:
-        pickle.dump(songs, f)
+        with open(f'./data/pickles/{songname:s}.pickle', mode='wb') as f:
+            pickle.dump(song, f)
 
 
 def music_for_test(serv, deletemusic=True, verbose=False):
diff --git a/cnn_train.py b/cnn_train.py
index 8e2b0c6e88c25f3ae7308adb7f6b69fd5812f00e..b6325379525cdf6035c656e77501cca05c4de986 100644
--- a/cnn_train.py
+++ b/cnn_train.py
@@ -1,13 +1,19 @@
 from cnn.model import *
 from cnn.music_processor import *
+from glob import glob
 
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 net = convNet()
 net = net.to(device)
-
-with open('./data/pickles/train_data.pickle', mode='rb') as f:
-    songs = pickle.load(f)
+
+songplaces = glob('./data/pickles/*.pickle')
+songs = []
+
+for songplace in songplaces:
+    with open(songplace, mode='rb') as f:
+        song = pickle.load(f)
+        songs.append(song)
 
 minibatch = 128
 soundlen = 15