diff --git a/.gitignore b/.gitignore
index d5283ab6a3cd51595be63e66e80e56a621715c5f..1504b226c7ba1d4903cb5fca38fc9edf28e7c04f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,6 @@
 !autokara.py
 !segment.py
 !assUtils.py
+!music_processor.py
+!model.py
 media/
\ No newline at end of file
diff --git a/assUtils.py b/assUtils.py
index 610081f7cf6cafd1d836c8e4e7c1d994d133ed8d..510e639f8c31ea0d41365f84233cf7abffabe374 100644
--- a/assUtils.py
+++ b/assUtils.py
@@ -11,6 +11,15 @@ def timeToDate(time):
     return f'{hours:02d}:{remainder_mins:02d}:{remainder_sec:.2f}'
 
 
+def dateToTime(date):
+    """
+    Converts `date`, an ASS timestamp in the format H:MM:SS.cs, into an integer
+    number of centiseconds.
+    """
+    hoursInMinutes = int(date[0:1]) * 60
+    minutesInSeconds = (int(date[2:4]) + hoursInMinutes) * 60
+    secondsInCentiseconds = (int(date[5:7]) + minutesInSeconds) * 100
+    return int(date[8:10]) + secondsInCentiseconds
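+# Quick sanity check (illustrative values only): dateToTime("0:01:23.45") returns
+# 8345 centiseconds, and timeToDate(8345 / 100) should format it back as "00:01:23.45".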
+
 
 class AssWriter:
 
@@ -44,12 +53,12 @@ Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
 
     def writeSyls(self, syl_timings):
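+        # `syl_timings` is expected to be a list of (time_in_seconds, syllable)
+        # pairs, e.g. [(83.45, 'ka'), (83.95, 'ra')] (illustrative values); the
+        # last syllable always gets the fixed 500 cs duration below, so this
+        # example would emit one Dialogue line ending in '{\k50}ka{\k500}ra'.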
         last_syl_dur = 500
-        start_time = timeToDate(syl_timings[0])
-        end_time = timeToDate(syl_timings[-1] + last_syl_dur//100)
+        start_time = timeToDate(syl_timings[0][0])
+        end_time = timeToDate(syl_timings[-1][0] + last_syl_dur//100)
         line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,0,,'
         for i in range(len(syl_timings) - 1):
-            syl_dur = round((syl_timings[i+1] - syl_timings[i]) * 100)
-            line += f'{{\k{syl_dur:d}}}'
-        line += f'{{\k{last_syl_dur:d}}}\n'
+            syl_dur = round((syl_timings[i+1][0] - syl_timings[i][0]) * 100)
+            line += f'{{\k{syl_dur:d}}}{syl_timings[i][1]:s}'
+        line += f'{{\k{last_syl_dur:d}}}{syl_timings[-1][1]:s}\n'
 
         self.file.write(line)
\ No newline at end of file
diff --git a/karaUtils.py b/karaUtils.py
deleted file mode 100755
index c47d07e1e3d5d1ce4fbd8a70f49e44a05487506c..0000000000000000000000000000000000000000
--- a/karaUtils.py
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/usr/bin/env python3
-import re
-import sys
-
-try:
-    FILE = sys.argv[1]
-except IndexError:
-    print("usage : %s inputFile.py" % sys.argv[0])
-
-with open(FILE, 'r') as f:
-    CONTENT = f.read()
-
-LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n");
-
-RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
-
-SYLS = []
-
-def dateToTime(date):
-    """
-    The `date` should be in the following format: HH:MM:SS.cs
-    """
-    hourInMinuts = int(date[0:1]) * 60
-    minutsInSeconds = (int(date[3:4]) + hourInMinuts) * 60
-    secondsInCentiseconds = (int(date[6:7]) + minutsInSeconds) * 100
-    return int(date[9:10]) + secondsInCentiseconds
-
-for line in LINES_KARA.findall(CONTENT):
-    lastTime = dateToTime(line[0])
-    for couple in RGX_TAGS.findall(line[2]):
-        SYLS.append((lastTime, couple[1], couple[0]))
-        lastTime += int(couple[0])
-
-print(SYLS)
diff --git a/model.py b/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3715c5cc1dbae6821e1351362ac55736dbb64e95
--- /dev/null
+++ b/model.py
@@ -0,0 +1,217 @@
+import pickle
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import numpy as np
+from tqdm import tqdm
+from music_processor import *
+
+"""
+On the paper,
+Starting from a stack of three spectrogram excerpts,
+convolution and max-pooling in turns compute a set of 20 feature maps 
+classified with a fully-connected network.
+"""
+
+class convNet(nn.Module):
+    """
+    Reimplements the neural net used in the paper
+    "Improved musical onset detection with Convolutional Neural Networks".
+    src: https://ieeexplore.ieee.org/document/6854953
+    """
+
+    def __init__(self):
+
+        super(convNet, self).__init__()
+        # model
+        self.conv1 = nn.Conv2d(3, 10, (3, 7))
+        self.conv2 = nn.Conv2d(10, 20, 3)
+        self.fc1 = nn.Linear(1120, 256)
+        self.fc2 = nn.Linear(256, 120)
+        self.fc3 = nn.Linear(120, 1)
+
+
+    def forward(self, x, istraining=False, minibatch=1):
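+        # Shape walk-through, assuming the (3, 80, 15) excerpts produced by
+        # music_processor with its default settings:
+        #   conv1 (3x7) -> (N, 10, 78, 9),  pool (3, 1) -> (N, 10, 26, 9)
+        #   conv2 (3x3) -> (N, 20, 24, 7),  pool (3, 1) -> (N, 20, 8, 7)
+        #   flatten     -> 20 * 8 * 7 = 1120, matching fc1's input size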
+
+        x = F.max_pool2d(F.relu(self.conv1(x)), (3, 1))
+        x = F.max_pool2d(F.relu(self.conv2(x)), (3, 1))
+        x = F.dropout(x.view(minibatch, -1), training=istraining)
+        x = F.dropout(F.relu(self.fc1(x)), training=istraining)
+        x = F.dropout(F.relu(self.fc2(x)), training=istraining)
+
+        return torch.sigmoid(self.fc3(x))
+
+
+    def train_data_builder(self, feats, answer, major_note_index, samplerate, soundlen=15, minibatch=1, split=0.2):
+        """
+        Args:
+            feats: song.feats; Audio module
+            answer: song.answer; soft onset labels, one per feature frame
+            major_note_index: frame indices of the notes to train on
+            samplerate: song.samplerate; Audio module
+            soundlen: =15. width (in frames) of the spectrogram excerpt fed to the model; each sample is an (80 * 15) image here
+            minibatch: training minibatch size
+            split: fraction of the shuffled note indices used per call (default 0.2)
+        Variables:
+            minspace: minimum spacing between major note indices
+            maxspace: maximum spacing between major note indices
+            idx: shuffled indices into major_note_index (and hence feats)
+            dist: distance (in frames) between two consecutive notes
+        """
+
+        # acceptable interval in seconds
+        minspace = 0.1
+        maxspace = 0.7
+
+        idx = np.random.permutation(major_note_index.shape[0] - soundlen) + soundlen // 2
+        X, y = [], []
+        cnt = 0
+        
+        for i in range(int(idx.shape[0] * split)):
+
+            dist = major_note_index[idx[i] + 1] - major_note_index[idx[i]]  # spacing (in frames) to the next note
+            
+            if dist < maxspace * samplerate / 512 and dist > minspace * samplerate / 512:    
+                for j in range(-1, dist + 2):
+                    X.append(feats[:, :, major_note_index[idx[i]] - soundlen // 2 + j : major_note_index[idx[i]] + soundlen // 2 + j + 1])
+                    y.append(answer[major_note_index[idx[i]] + j])
+                    cnt += 1
+                    
+                    if cnt % minibatch == 0:
+                        yield (torch.from_numpy(np.array(X)).float(), torch.from_numpy(np.array(y)).float())
+                        X, y = [], []
+
+
+    def infer_data_builder(self, feats, soundlen=15, minibatch=1):
+        
+        x = []
+        
+        for i in range(feats.shape[2] - soundlen):
+            x.append(feats[:, :, i:i+soundlen])
+        
+            if (i + 1) % minibatch == 0:
+                yield (torch.from_numpy(np.array(x)).float())
+                x = []
+        
+        if len(x) != 0:
+            yield (torch.from_numpy(np.array(x)).float())
+
+
+    def train(self, songs, minibatch, epoch, device, soundlen=15, val_song=None, save_place='./models/model.pth', log='./log/log.txt', don_ka=0):
+        """
+        Args:
+            songs: the list of songs to train on
+            minibatch: minibatch size
+            epoch: number of training epochs
+            device: cpu / gpu
+            soundlen: width (in frames) of one training sample's image
+            val_song: validation song; to validate while training, pass the validation song data here.
+            save_place: path where the trained model is saved
+            log: path of the log file
+            don_ka: don(1), ka(2) or both(0); usually train don first, then ka.
+        """
+
+        for song in songs:
+            
+            timing = song.timestamp[:, 0]
+            sound  = song.timestamp[:, 1]
+            song.answer = np.zeros((song.feats.shape[2]))
+
+            if don_ka == 0:
+                song.major_note_index = np.rint(timing[np.where(sound != 0)] * song.samplerate/512).astype(np.int32)
+                song.minor_note_index = np.array([], dtype=np.int32)  # no "other drum" notes in this mode
+            else:
+                song.major_note_index = np.rint(timing[np.where(sound == don_ka)] * song.samplerate/512).astype(np.int32)
+                song.minor_note_index = np.rint(timing[np.where(sound == 3-don_ka)] * song.samplerate/512).astype(np.int32)
+            
+            song.major_note_index = np.delete(song.major_note_index, np.where(song.major_note_index >= song.feats.shape[2]))
+            song.minor_note_index = np.delete(song.minor_note_index, np.where(song.minor_note_index >= song.feats.shape[2]))
+            song.answer[song.major_note_index] = 1
+            song.answer[song.minor_note_index] = 0.26
+            song.answer = milden(song.answer)
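+            # Label scheme: 1.0 on the notes being trained, 0.26 on the other
+            # drum's notes (when don_ka != 0); milden() then spreads softened
+            # values (0.25 / 0.1) onto the neighbouring frames.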
+
+        if val_song:
+
+            timing = val_song.timestamp[:, 0]
+            sound = val_song.timestamp[:, 1]
+            val_song.answer = np.zeros((val_song.feats.shape[2]))
+
+            if don_ka == 0:
+                val_song.major_note_index = np.rint(timing[np.where(sound != 0)] * val_song.samplerate/512).astype(np.int32)
+                val_song.minor_note_index = np.array([], dtype=np.int32)  # no "other drum" notes in this mode
+            else:
+                val_song.major_note_index = np.rint(timing[np.where(sound == don_ka)] * val_song.samplerate/512).astype(np.int32)
+                val_song.minor_note_index = np.rint(timing[np.where(sound == 3-don_ka)] * val_song.samplerate/512).astype(np.int32)
+
+            val_song.major_note_index = np.delete(val_song.major_note_index, np.where(val_song.major_note_index >= val_song.feats.shape[2]))
+            val_song.minor_note_index = np.delete(val_song.minor_note_index, np.where(val_song.minor_note_index >= val_song.feats.shape[2]))
+            val_song.answer[val_song.major_note_index] = 1
+            val_song.answer[val_song.minor_note_index] = 0.26
+            val_song.answer = milden(val_song.answer)
+
+        # training
+        optimizer = optim.SGD(self.parameters(), lr=0.02)
+        criterion = nn.MSELoss()
+        running_loss = 0
+        val_loss = 0
+
+        for i in range(epoch):
+            for song in songs:
+                for X, y in self.train_data_builder(song.feats, song.answer, song.major_note_index, song.samplerate, soundlen, minibatch, split=0.2):
+                    optimizer.zero_grad()
+                    output = self(X.to(device), istraining=True, minibatch=minibatch)
+                    target = y.to(device)
+                    loss = criterion(output.squeeze(), target)
+                    loss.backward()
+                    optimizer.step()
+                    running_loss += loss.data.item()
+
+            with open(log, 'a') as f:
+                print("epoch: %.d running_loss: %.10f " % (i+1, running_loss), file=f)
+
+            print("epoch: %.d running_loss: %.10f" % (i+1, running_loss))
+            
+            running_loss = 0    
+
+            if val_song:
+                inference = torch.from_numpy(self.infer(val_song.feats, device, minibatch=512)).to(device)
+                target = torch.from_numpy(val_song.answer[:-soundlen]).float().to(device)
+                loss = criterion(inference.squeeze(), target)
+                val_loss = loss.data.item()
+
+                with open(log, 'a') as f:
+                    print("val_loss: %.10f " % (val_loss), file=f)
+
+        torch.save(self.state_dict(), save_place)
+
+
+    def infer(self, feats, device, minibatch=1):
+
+        with torch.no_grad():
+            inference = None
+            for x in tqdm(self.infer_data_builder(feats, minibatch=minibatch), total=feats.shape[2]//minibatch):
+                output = self(x.to(device), minibatch=x.shape[0])
+                if inference is not None:
+                    inference = np.concatenate((inference, output.cpu().numpy().reshape(-1)))
+                else:
+                    inference = output.cpu().numpy().reshape(-1)
+            
+            return np.array(inference).reshape(-1)
+
+
+if __name__ == '__main__':
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    net = convNet()
+    net = net.to(device)
+
+    with open('./data/pickles/train_data.pickle', mode='rb') as f:
+        songs = pickle.load(f)
+
+    minibatch = 128
+    soundlen = 15
+    epoch = 100
+
+    if sys.argv[1] == 'don':
+        net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/don_model.pth', log='./data/log/don.txt', don_ka=1)
+    
+    if sys.argv[1] == 'ka':
+        net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/ka_model.pth', log='./data/log/ka.txt', don_ka=2)
\ No newline at end of file
diff --git a/music_processor.py b/music_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ba18d96187ee07403705e4ce008368113241b4a
--- /dev/null
+++ b/music_processor.py
@@ -0,0 +1,345 @@
+import soundfile as sf
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+from glob import glob
+from scipy import signal
+from scipy.fftpack import fft
+from librosa.filters import mel
+from librosa.display import specshow
+from librosa import stft
+from librosa.effects import pitch_shift
+import pickle
+import sys
+from numba import jit, prange
+from sklearn.preprocessing import normalize
+import re
+from assUtils import dateToTime, timeToDate
+
+
+class Audio:
+    """
+    audio class which holds music data and timestamp for notes.
+    Args:
+        filename: file name.
+        stereo: True or False; whether the file is a stereo Don/Ka recording or not. Normally True.
+    Variables:
+    Example:
+        >>> from music_processor import *
+        >>> song = Audio(filename)
+        >>> # to get audio data
+        >>> song.data
+        >>> # to import .tja files:
+        >>> song.import_tja(filename)
+        >>> # to downmix the data to mono
+        >>> song.data = (song.data[:,0]+song.data[:,1])/2
+        >>> fft_and_melscale(song, include_zero_cross=False)
+    """
+
+    def __init__(self, filename, stereo=True):
+
+        self.data, self.samplerate = sf.read(filename, always_2d=True)
+        if stereo is False:
+            self.data = (self.data[:, 0]+self.data[:, 1])/2
+        self.timestamp = []
+
+
+    def plotaudio(self, start_t, stop_t):
+
+        plt.plot(np.linspace(start_t, stop_t, stop_t-start_t), self.data[start_t:stop_t, 0])
+        plt.show()
+
+
+    def save(self, filename="./savedmusic.wav", start_t=0, stop_t=None):
+
+        if stop_t is None:
+            stop_t = self.data.shape[0]
+        sf.write(filename, self.data[start_t:stop_t], self.samplerate)
+
+
+    def import_tja(self, filename, verbose=False, diff=False, difficulty=None):
+        """imports tja file and convert it into timestamp"""
+        
+        now = 0.0
+        bpm = 100
+        measure = [4, 4]  # time signature (hyoushi)
+        self.timestamp = []
+        skipflag = False
+
+        with open(filename, "rb") as f:
+            while True:
+                line = f.readline()
+                try:
+                    line = line.decode('sjis')
+                except UnicodeDecodeError:
+                    line = line.decode('utf-8')
+                if line.find('//') != -1:
+                    line = line[:line.find('//')]
+                if line[0:5] == "TITLE":
+                    if verbose:
+                        print("importing: ", line[6:])
+                elif line[0:6] == "OFFSET":
+                    now = -float(line[7:-2])
+                elif line[0:4] == "BPM:":
+                    bpm = float(line[4:-2])
+                if line[0:6] == "COURSE":
+                    if difficulty and difficulty > 0:
+                        skipflag = True
+                        difficulty -= 1
+                elif line == "#START\r\n":
+                    if skipflag:
+                        skipflag = False
+                        continue
+                    break
+            
+            sound = []
+            while True:
+                line = f.readline()
+                # print(line)
+                try:
+                    line = line.decode('sjis')
+                except UnicodeDecodeError:
+                    line = line.decode('utf-8')
+
+                if line.find('//') != -1:
+                    line = line[:line.find('//')]
+                if line[0] <= '9' and line[0] >= '0':
+                    if line.find(',') != -1:
+                        sound += line[0:line.find(',')]
+                        beat = len(sound)
+                        for i in range(beat):
+                            if diff:
+                                if int(sound[i]) in (1, 3, 5, 6, 7):
+                                    self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, 1])
+                                elif int(sound[i]) in (2, 4):
+                                    self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, 2])
+                            else:
+                                if int(sound[i]) != 0:
+                                    self.timestamp.append([int(100*(now+i*60*measure[0]/bpm/beat))/100, int(sound[i])])
+                        now += 60/bpm*measure[0]
+                        sound = []
+                    else:
+                        sound += line[0:-2]
+                elif line[0] == ',':
+                    now += 60/bpm*measure[0]
+                elif line[0:10] == '#BPMCHANGE':
+                    bpm = float(line[11:-2])
+                elif line[0:8] == '#MEASURE':
+                    measure[0] = int(line[line.find('/')-1])
+                    measure[1] = int(line[line.find('/')+1])
+                elif line[0:6] == '#DELAY':
+                    now += float(line[7:-2])
+                elif line[0:4] == "#END":
+                    if verbose:
+                        print("import complete!")
+                    self.timestamp = np.array(self.timestamp)
+                    break
+
+
+
+    def import_ass(self, filename):
+
+        with open(filename, 'r') as f:
+            CONTENT = f.read()
+
+        LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}\.\d{2}),(\d+:\d{2}:\d{2}\.\d{2}),.*,karaoke,(.*)\n")
+
+        RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
+
+        SYLS = []
+
+        for line in LINES_KARA.findall(CONTENT):
+            lastTime = dateToTime(line[0])
+            for couple in RGX_TAGS.findall(line[2]):
+                self.timestamp.append((lastTime/100, couple[1]))
+                lastTime += int(couple[0])
+        # print(type(lastTime))
+        # print(self.timestamp)
+        self.timestamp = np.array(self.timestamp, dtype='float, object')
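+        # self.timestamp ends up as a structured array with one
+        # (time_in_seconds, syllable_text) record per karaoke syllable.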
+
+
+def make_frame(data, nhop, nfft):
+    """
+    helper function for fft_and_melscale.
+    To turn short time slices into training data, returns an array of nfft-sized windows, each shifted by nhop (512) samples.
+    """
+    
+    length = data.shape[0]
+    framedata = np.concatenate((data, np.zeros(nfft)))  # zero padding
+    return np.array([framedata[i*nhop:i*nhop+nfft] for i in range(length//nhop)])  
+
+
+@jit
+def fft_and_melscale(song, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False):
+    """
+    fft and melscale method.
+    fft: nfft = [1024, 2048, 4096]; extracts np.arrays from the data while varying the window length and applies a fast Fourier transform.
+    melscale: reduces the frequency dimension to mel bands and takes log10 of the values.
+    """
+
+    feat_channels = []
+    
+    for nfft in nffts:
+        
+        feats = []
+        window = signal.blackmanharris(nfft)
+        filt = mel(song.samplerate, nfft, mel_nband, mel_freqlo, mel_freqhi)
+        
+        # get normal frame
+        frame = make_frame(song.data, nhop, nfft)
+        # print(frame.shape)
+
+        # melscaling
+        processedframe = fft(window*frame)[:, :nfft//2+1]
+        processedframe = np.dot(filt, np.transpose(np.abs(processedframe)**2))
+        processedframe = 20*np.log10(processedframe+0.1)
+        # print(processedframe.shape)
+
+        feat_channels.append(processedframe)
+    
+    if include_zero_cross:
+        song.zero_crossing = np.where(np.diff(np.sign(song.data)))[0]
+        print(song.zero_crossing)
+    
+    return np.array(feat_channels)
+
+
+@jit(parallel=True)
+def multi_fft_and_melscale(songs, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False):
+    
+    for i in prange(len(songs)):
+        songs[i].feats = fft_and_melscale(songs[i], nhop, nffts, mel_nband, mel_freqlo, mel_freqhi)
+
+
+def milden(data):
+    """put smaller value(0.25) to plus minus 1 frame."""
+    
+    for i in range(data.shape[0]):
+        
+        if data[i] == 1:
+            if i > 0:
+                data[i-1] = 0.25
+            if i < data.shape[0] - 1:
+                data[i+1] = 0.25
+        
+        if data[i] == 0.26:
+            if i > 0:
+                data[i-1] = 0.1
+            if i < data.shape[0] - 1:
+                data[i+1] = 0.1
+    
+    return data
+
+
+def smooth(x, window_len=11, window='hanning'):
+    
+    if x.ndim != 1:
+        raise ValueError("smooth only accepts 1-dimensional arrays")
+
+    if x.size < window_len:
+        raise ValueError("input vector needs to be bigger than the window size")
+
+    if window_len < 3:
+        return x
+
+    if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
+        raise ValueError("window must be one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")
+
+    s = np.r_[x[window_len-1:0:-1], x, x[-2:-window_len-1:-1]]
+    # print(len(s))
+    if window == 'flat':  # moving average
+        w = np.ones(window_len, 'd')
+    else:
+        w = getattr(np, window)(window_len)  # e.g. np.hanning(window_len)
+
+    y = np.convolve(w/w.sum(), s, mode='valid')
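+    # note: because of the reflected padding, y has len(x) + window_len - 1
+    # samples rather than len(x)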
+    
+    return y
+
+
+
+def music_for_validation(serv, deletemusic=True, verbose=False, difficulty=1):
+
+    song = Audio(glob(serv+"/*.ogg")[0], stereo=False)
+    song.import_tja(glob(serv+"/*.tja")[-1], difficulty=difficulty)
+    song.feats = fft_and_melscale(song, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False)
+
+    if deletemusic:
+        song.data = None
+    with open('./data/pickles/val_data.pickle', mode='wb') as f:
+        pickle.dump(song, f)
+
+
+def music_for_train(serv, deletemusic=True, verbose=False, difficulty=0, diff=False, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False):
+    
+    songplaces = glob(serv)
+    songs = []
+    
+    for songplace in songplaces:
+        
+        if verbose:
+            print(songplace)
+        
+        song = Audio(glob(songplace+"/*.ogg")[0])
+        song.import_tja(glob(songplace+"/*.tja")[-1], difficulty=difficulty, diff=True)
+        song.data = (song.data[:, 0]+song.data[:, 1])/2
+        songs.append(song)
+
+    multi_fft_and_melscale(songs, nhop, nffts, mel_nband, mel_freqlo, mel_freqhi, include_zero_cross=include_zero_cross)
+    
+    if deletemusic:
+        for song in songs:
+            song.data = None
+    
+    with open('./data/pickles/train_data.pickle', mode='wb') as f:
+        pickle.dump(songs, f)
+
+def music_for_train_reduced(serv, deletemusic=True, verbose=False, difficulty=0, diff=False, nhop=512, nffts=[1024, 2048, 4096], mel_nband=80, mel_freqlo=27.5, mel_freqhi=16000.0, include_zero_cross=False):
+    
+    songplaces = glob(serv)
+    songs = []
+    
+    for songplace in songplaces:
+        
+        if verbose:
+            print(songplace)
+        
+        song = Audio(glob(songplace+"/*.ogg")[0])
+        song.import_tja(glob(songplace+"/*.tja")[-1], difficulty=difficulty, diff=True)
+        song.data = (song.data[:, 0]+song.data[:, 1])/2
+        songs.append(song)
+
+    multi_fft_and_melscale(songs, nhop, nffts, mel_nband, mel_freqlo, mel_freqhi, include_zero_cross=include_zero_cross)
+    
+    if deletemusic:
+        for song in songs:
+            song.data = None
+    
+    with open('./data/pickles/train_reduced.pickle', mode='wb') as f:
+        pickle.dump(songs, f)
+
+
+def music_for_test(serv, deletemusic=True, verbose=False):
+
+    song = Audio(glob(serv+"/*.ogg")[0], stereo=False)
+    # song.import_tja(glob(serv+"/*.tja")[-1])
+    song.feats = fft_and_melscale(song, include_zero_cross=False)
+    with open('./data/pickles/test_data.pickle', mode='wb') as f:
+        pickle.dump(song, f)
+
+
+if __name__ == "__main__":
+
+    if sys.argv[1] == 'train':
+        print("preparing all train data processing...")
+        serv = "./data/train/*"
+        music_for_train(serv, verbose=True, difficulty=0, diff=True)
+        print("all train data processing done!")    
+
+    if sys.argv[1] == 'test':
+        print("test data proccesing...")
+        serv = "./data/test/"
+        music_for_test(serv)
+        print("test data processing done!")
+
+