From 342d47e50ed33fff87188204fc74d8c82e52e044 Mon Sep 17 00:00:00 2001
From: Sting <loic.allegre@ensiie.fr>
Date: Fri, 23 Jun 2023 11:46:30 +0200
Subject: [PATCH] Cleanup and reorganization

---
 .gitignore                                   | 10 ++++++----
 autokara.py                                  |  5 ++---
 model.py => cnn/model.py                     |  2 +-
 music_processor.py => cnn/music_processor.py |  0
 infer.py => cnn/segment.py                   |  4 ++--
 cnn_prepare_data.py                          | 15 +++++++++++++++
 cnn_train.py                                 | 17 +++++++++++++++++
 segment.py => rosa/segment.py                |  0
 8 files changed, 43 insertions(+), 10 deletions(-)
 rename model.py => cnn/model.py (99%)
 rename music_processor.py => cnn/music_processor.py (100%)
 rename infer.py => cnn/segment.py (95%)
 create mode 100644 cnn_prepare_data.py
 create mode 100644 cnn_train.py
 rename segment.py => rosa/segment.py (100%)

diff --git a/.gitignore b/.gitignore
index 65e0cef..0b531c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,10 +6,12 @@
 !extractAss.sh
 !karaUtils.py
 !autokara.py
-!segment.py
 !assUtils.py
-!music_processor.py
-!model.py
 !process_train_data.sh
-!infer.py
+!cnn_prepare_data.py
+!cnn_train.py
+!*/cnn/segment.py
+!*/cnn/music_processor.py
+!*/cnn/model.py
+!rosa/*.py
 media/
\ No newline at end of file
diff --git a/autokara.py b/autokara.py
index 4f672fb..60a67f1 100644
--- a/autokara.py
+++ b/autokara.py
@@ -6,8 +6,7 @@ import shlex
 from pathlib import Path
 from assUtils import AssWriter
 
-from segment import Segment
-import infer
+from cnn.segment import segment
 
 
 parser = argparse.ArgumentParser(description='AutoKara - Automatic karaoke timing tool')
@@ -42,7 +41,7 @@ else:
 
 
 print("Identifying syl starts...")
-onsets = infer.segment(sys.argv[1])
+onsets = segment(sys.argv[1])
 syls = [[t, ''] for t in onsets]
 
 print("Syls found, writing ASS file...")
diff --git a/model.py b/cnn/model.py
similarity index 99%
rename from model.py
rename to cnn/model.py
index 31522f3..c72be39 100644
--- a/model.py
+++ b/cnn/model.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 import torch.optim as optim
 import numpy as np
 from tqdm import tqdm
-from music_processor import *
+from cnn.music_processor import *
 
 """
 On the paper,
diff --git a/music_processor.py b/cnn/music_processor.py
similarity index 100%
rename from music_processor.py
rename to cnn/music_processor.py
diff --git a/infer.py b/cnn/segment.py
similarity index 95%
rename from infer.py
rename to cnn/segment.py
index e983b45..13f293c 100644
--- a/infer.py
+++ b/cnn/segment.py
@@ -1,5 +1,5 @@
-from model import *
-from music_processor import *
+from cnn.model import *
+from cnn.music_processor import *
 from assUtils import AssWriter
 import pickle
 import numpy as np
diff --git a/cnn_prepare_data.py b/cnn_prepare_data.py
new file mode 100644
index 0000000..81b96fa
--- /dev/null
+++ b/cnn_prepare_data.py
@@ -0,0 +1,15 @@
+from cnn.music_processor import *
+
+
+
+if sys.argv[1] == 'train':
+    print("preparing all train data processing...")
+    serv = f'./{sys.argv[2]:s}/*'
+    music_for_train(serv, verbose=True)
+    print("all train data processing done!")    
+
+if sys.argv[1] == 'test':
+    print("test data proccesing...")
+    serv = f'./{sys.argv[2]:s}/*'
+    music_for_test(serv)
+    print("test data processing done!")
\ No newline at end of file
diff --git a/cnn_train.py b/cnn_train.py
new file mode 100644
index 0000000..8e2b0c6
--- /dev/null
+++ b/cnn_train.py
@@ -0,0 +1,17 @@
+from cnn.model import *
+from cnn.music_processor import *
+
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+net = convNet()
+net = net.to(device)
+        
+with open('./data/pickles/train_data.pickle', mode='rb') as f:
+    songs = pickle.load(f)
+
+minibatch = 128
+soundlen = 15
+epoch = 100
+
+
+net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/model.pth', log='./data/log/log.txt')
\ No newline at end of file
diff --git a/segment.py b/rosa/segment.py
similarity index 100%
rename from segment.py
rename to rosa/segment.py
-- 
GitLab