From f036c17a606b7d1aaac438c6684feec2f4ce6ee5 Mon Sep 17 00:00:00 2001
From: Sting <loic.allegre@ensiie.fr>
Date: Fri, 23 Jun 2023 10:52:15 +0200
Subject: [PATCH] Model training and syllable prediction

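Add infer.py: it loads the trained convNet and the pickled song features,
runs inference, picks onset peaks with librosa's peak_pick, and writes the
detected syllable onsets to an ASS file via AssWriter.

convNet.train() now builds its target vector from syllable timestamps
([time, text] pairs) instead of the don/ka drum labels, so the don_ka
parameter and the separate validation-song target construction are removed.
The no-longer-needed TJA importer is dropped from music_processor.py.

Usage sketch (assumes ./data/pickles/train_data.pickle and
./models/model.pth already exist):

    python infer.py <songfile>

Note that the songfile argument is not used yet; segment() reads its
features from the pickled training data.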
---
 .gitignore         |  1 +
 infer.py           | 61 ++++++++++++++++++++++++++++++++++++
 model.py           | 40 +++++-------------------
 music_processor.py | 78 ----------------------------------------------
 4 files changed, 70 insertions(+), 110 deletions(-)
 create mode 100644 infer.py

diff --git a/.gitignore b/.gitignore
index 3679883..65e0cef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,5 @@
 !music_processor.py
 !model.py
 !process_train_data.sh
+!infer.py
 media/
\ No newline at end of file
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..bc1b6f3
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,61 @@
+from model import *
+from music_processor import *
+from assUtils import AssWriter
+import pickle
+import numpy as np
+from scipy.signal import argrelmax
+from librosa.util import peak_pick
+from librosa.onset import onset_detect
+
+
+def segment(songfile):
+    """Run the trained convNet on the pickled song features and return onset times in seconds."""
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    net = convNet()
+    net = net.to(device)
+
+    with open('./data/pickles/train_data.pickle', mode='rb') as f:
+        songs = pickle.load(f)
+
+
+    if torch.cuda.is_available():
+        net.load_state_dict(torch.load('./models/model.pth'))
+    else:
+        net.load_state_dict(torch.load('./models/model.pth', map_location='cpu'))
+
+    song = songs[0]  # NOTE: the songfile argument is not used yet; features come from the pickled data
+    inference = net.infer(song.feats, device, minibatch=4192)
+    inference = np.reshape(inference, (-1))
+
+    return detection(inference, song.samplerate)
+
+
+
+def detection(inference, samplerate):
+    """Pick peaks in the frame-wise activation and convert frame indices to seconds."""
+    inference = smooth(inference, 5)
+
+
+    timestamp = peak_pick(inference, pre_max=1, post_max=2, pre_avg=4, post_avg=5, delta=0.05, wait=3)  # actually corresponds to the sound around the 7th frame
+
+    timestamp = timestamp*512/samplerate  # frame index -> seconds (hop length of 512 samples)
+
+    return timestamp
+
+
+
+if __name__ == '__main__':
+
+    onsets = segment(sys.argv[1])
+    syls = [[t, ''] for t in onsets]
+
+    print(syls)
+
+    writer = AssWriter()
+    writer.openAss("./media/test.ass")
+    writer.writeHeader()
+    writer.writeSyls(syls)
+    writer.closeAss()
+
+
+
diff --git a/model.py b/model.py
index 3715c5c..31522f3 100644
--- a/model.py
+++ b/model.py
@@ -97,7 +97,7 @@ class convNet(nn.Module):
             yield (torch.from_numpy(np.array(x)).float())
 
 
-    def train(self, songs, minibatch, epoch, device, soundlen=15, val_song=None, save_place='./models/model.pth', log='./log/log.txt', don_ka=0):
+    def train(self, songs, minibatch, epoch, device, soundlen=15, val_song=None, save_place='./models/model.pth', log='./log/log.txt'):
         """
         Args:
             songs: the list of song
@@ -113,39 +113,18 @@ class convNet(nn.Module):
 
         for song in songs:
             
-            timing = song.timestamp[:, 0]
-            sound  = song.timestamp[:, 1]
+            timing = np.array([syl[0] for syl in song.timestamp])
+            syllable = np.array([syl[1] for syl in song.timestamp])
             song.answer = np.zeros((song.feats.shape[2]))
 
-            if don_ka == 0:
-                song.major_note_index = np.rint(timing[np.where(sound != 0)] * song.samplerate/512).astype(np.int32)
-            else:
-                song.major_note_index = np.rint(timing[np.where(sound == don_ka)] * song.samplerate/512).astype(np.int32)
-                song.minor_note_index = np.rint(timing[np.where(sound == 3-don_ka)] * song.samplerate/512).astype(np.int32)
+            # frame indices (hop length 512) of timestamps with non-empty syllable text
+            song.major_note_index = np.rint(timing[np.where(syllable != "")] * song.samplerate/512).astype(np.int32)
             
             song.major_note_index = np.delete(song.major_note_index, np.where(song.major_note_index >= song.feats.shape[2]))
-            song.minor_note_index = np.delete(song.minor_note_index, np.where(song.minor_note_index >= song.feats.shape[2]))
-            song.answer[song.major_note_index] = 1
-            song.answer[song.minor_note_index] = 0.26
-            song.answer = milden(song.answer)
-
-        if val_song:
 
-            timing = val_song.timestamp[:, 0]
-            sound = val_song.timestamp[:, 1]
-            val_song.answer = np.zeros((val_song.feats.shape[2]))
-
-            if don_ka == 0:
-                val_song.major_note_index = np.rint(timing[np.where(sound != 0)] * val_song.samplerate/512).astype(np.int32)
-            else:
-                val_song.major_note_index = np.rint(timing[np.where(sound == don_ka)] * val_song.samplerate/512).astype(np.int32)
-                val_song.minor_note_index = np.rint(timing[np.where(sound == 3-don_ka)] * val_song.samplerate/512).astype(np.int32)
+            song.answer[song.major_note_index] = 1
 
-            val_song.major_note_index = np.delete(val_song.major_note_index, np.where(val_song.major_note_index >= val_song.feats.shape[2]))
-            val_song.minor_note_index = np.delete(val_song.minor_note_index, np.where(val_song.minor_note_index >= val_song.feats.shape[2]))
-            val_song.answer[val_song.major_note_index] = 1
-            val_song.answer[val_song.minor_note_index] = 0.26
-            val_song.answer = milden(val_song.answer)
+            song.answer = milden(song.answer)
 
         # training
         optimizer = optim.SGD(self.parameters(), lr=0.02)
@@ -210,8 +189,5 @@ if __name__ == '__main__':
     soundlen = 15
     epoch = 100
 
-    if sys.argv[1] == 'don':
-        net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/don_model.pth', log='./data/log/don.txt', don_ka=1)
     
-    if sys.argv[1] == 'ka':
-        net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/ka_model.pth', log='./data/log/ka.txt', don_ka=2)
\ No newline at end of file
+    net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/model.pth', log='./data/log/log.txt')
\ No newline at end of file
diff --git a/music_processor.py b/music_processor.py
index 78890e8..ce9d857 100644
--- a/music_processor.py
+++ b/music_processor.py
@@ -57,84 +57,6 @@ class Audio:
         sf.write(filename, self.data[start_t:stop_t], self.samplerate)
 
 
-    def import_tja(self, filename, verbose=False, diff=False, difficulty=None):
-        """imports tja file and convert it into timestamp"""
-        
-        now = 0.0
-        bpm = 100
-        measure = [4, 4]  # hyousi
-        self.timestamp = []
-        skipflag = False
-
-        with open(filename, "rb") as f:
-            while True:
-                line = f.readline()
-                try:
-                    line = line.decode('sjis')
-                except UnicodeDecodeError:
-                    line = line.decode('utf-8')
-                if line.find('//') != -1:
-                    line = line[:line.find('//')]
-                if line[0:5] == "TITLE":
-                    if verbose:
-                        print("importing: ", line[6:])
-                elif line[0:6] == "OFFSET":
-                    now = -float(line[7:-2])
-                elif line[0:4] == "BPM:":
-                    bpm = float(line[4:-2])
-                if line[0:6] == "COURSE":
-                    if difficulty and difficulty > 0:
-                        skipflag = True
-                        difficulty -= 1
-                elif line == "#START\r\n":
-                    if skipflag:
-                        skipflag = False
-                        continue
-                    break
-            
-            sound = []
-            while True:
-                line = f.readline()
-                # print(line)
-                try:
-                    line = line.decode('sjis')
-                except UnicodeDecodeError:
-                    line = line.decode('utf-8')
-
-                if line.find('//') != -1:
-                    line = line[:line.find('//')]
-                if line[0] <= '9' and line[0] >= '0':
-                    if line.find(',') != -1:
-                        sound += line[0:line.find(',')]
-                        beat = len(sound)
-                        for i in range(beat):
-                            if diff:
-                                if int(sound[i]) in (1, 3, 5, 6, 7):
-                                    self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, 1])
-                                elif int(sound[i]) in (2, 4):
-                                    self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, 2])
-                            else:
-                                if int(sound[i]) != 0:
-                                    self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, int(sound[i])])
-                        now += 60/bpm*measure[0]
-                        sound = []
-                    else:
-                        sound += line[0:-2]
-                elif line[0] == ',':
-                    now += 60/bpm*measure[0]
-                elif line[0:10] == '#BPMCHANGE':
-                    bpm = float(line[11:-2])
-                elif line[0:8] == '#MEASURE':
-                    measure[0] = int(line[line.find('/')-1])
-                    measure[1] = int(line[line.find('/')+1])
-                elif line[0:6] == '#DELAY':
-                    now += float(line[7:-2])
-                elif line[0:4] == "#END":
-                    if(verbose):
-                        print("import complete!")
-                    self.timestamp = np.array(self.timestamp)
-                    break
-
 
 
     def import_ass(self, filename):
-- 
GitLab