model.py
    import pickle

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import numpy as np
    from tqdm import tqdm
    from cnn.music_processor import *  # provides milden() and the Audio helpers referenced below
    
    """
    On the paper,
    Starting from a stack of three spectrogram excerpts,
    convolution and max-pooling in turns compute a set of 20 feature maps 
    classified with a fully-connected network.
    """
    
    class convNet(nn.Module):
        """
        copies the neural net used in a paper.
        "Improved musical onset detection with Convolutional Neural Networks".
        src: https://ieeexplore.ieee.org/document/6854953
        """
    
        def __init__(self):
    
            super(convNet, self).__init__()
            # model
            self.conv1 = nn.Conv2d(3, 10, (3, 7))
            self.conv2 = nn.Conv2d(10, 20, 3)
            self.fc1 = nn.Linear(1120, 256)
            self.fc2 = nn.Linear(256, 120)
            self.fc3 = nn.Linear(120, 1)
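
            # Shape check (a sketch, assuming (N, 3, 80, 15) inputs as described
            # in the train_data_builder docstring):
            #   conv1 (3x7):     (N, 3, 80, 15) -> (N, 10, 78, 9)
            #   max_pool (3, 1):                -> (N, 10, 26, 9)
            #   conv2 (3x3):                    -> (N, 20, 24, 7)
            #   max_pool (3, 1):                -> (N, 20, 8, 7)
            #   flatten: 20 * 8 * 7 = 1120, the input size of fc1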
    
    
        def forward(self, x, istraining=False, minibatch=1):
    
            x = F.max_pool2d(F.relu(self.conv1(x)), (3, 1))
            x = F.max_pool2d(F.relu(self.conv2(x)), (3, 1))
            # flatten to (minibatch, 1120); dropout is active only when istraining=True
            x = F.dropout(x.view(minibatch, -1), training=istraining)
            x = F.dropout(F.relu(self.fc1(x)), training=istraining)
            x = F.dropout(F.relu(self.fc2(x)), training=istraining)

            return torch.sigmoid(self.fc3(x))  # onset probability in [0, 1]
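
        # A minimal smoke test (illustrative only; the input size is assumed from
        # the docstrings in this file):
        #   net = convNet()
        #   out = net(torch.randn(4, 3, 80, 15), minibatch=4)
        #   out.shape  # -> torch.Size([4, 1])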
    
    
        def train_data_builder(self, feats, answer, major_note_index, samplerate, soundlen=15, minibatch=1, split=0.2):
            """
            Args:
                feats: song.feats; Audio module
                answers: song.answers; Audio module
                major_note_index: answer labels; corresponding to feats
                samplerate: song.samplerate; Audio module
                soundlen: =15. 学習モデルに渡す画像データの横方向の長さ.ここでは(80 * 15)サイズのデータを使用している
                minibatch: training minibatch
                split: =1. 
            Variables:
                minspace: minimum space between major note indexs
                maxspace: maximum space between major note indexs
                idx: index of major_note_index or feats
                dist: distance of two notes
            """
    
            # acceptable interval in seconds
            minspace = 0.1
            maxspace = 0.7
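
            # the check below multiplies by samplerate / 512 to convert these seconds
            # into feature-frame counts (assuming a hop size of 512 samples in
            # cnn.music_processor)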
    
            idx = np.random.permutation(major_note_index.shape[0] - soundlen) + soundlen // 2
            X, y = [], []
            cnt = 0
            
            for i in range(int(idx.shape[0] * split)):
    
                dist = major_note_index[idx[i] + 1] - major_note_index[idx[i]]  # frames until the next labeled note

                if minspace * samplerate / 512 < dist < maxspace * samplerate / 512:
                    for j in range(-1, dist + 2):
                        X.append(feats[:, :, major_note_index[idx[i]] - soundlen // 2 + j : major_note_index[idx[i]] + soundlen // 2 + j + 1])
                        y.append(answer[major_note_index[idx[i]] + j])
                        cnt += 1
                        
                        if cnt % minibatch == 0:
                            yield (torch.from_numpy(np.array(X)).float(), torch.from_numpy(np.array(y)).float())
                            X, y = [], []
    
    
        def infer_data_builder(self, feats, soundlen=15, minibatch=1):
            
            x = []
            
            for i in range(feats.shape[2] - soundlen):
                x.append(feats[:, :, i:i+soundlen])
            
                if (i + 1) % minibatch == 0:
                    yield (torch.from_numpy(np.array(x)).float())
                    x = []
            
            if len(x) != 0:  # flush the final partial minibatch
                yield (torch.from_numpy(np.array(x)).float())
    
    
        def train(self, songs, minibatch, epoch, device, soundlen=15, val_song=None, save_place='./models/model.pth', log='./log/log.txt'):
            """
            Args:
                songs: the list of song
                minibatch: minibatch value
                epoch: number of train 
                device: cpu / gpu
                soundlen: width of one train data's image
                val_song: validation song, if you wanna validation while training, give a path of validation song data.
                save_place: save place path
                log: log place path
                don-ka: don(1) or ka(2) or both(0), usually, firstly, train don, then, train ka.
            """
    
            for song in songs:

                # build a per-frame 0/1 answer vector from the song's timestamps
                timing = np.array([syl[0] for syl in song.timestamp])
                syllable = np.array([syl[1] for syl in song.timestamp])
                song.answer = np.zeros((song.feats.shape[2]))

                # keep only labeled notes and convert their times to frame indices
                song.major_note_index = np.rint(timing[np.where(syllable != "")] * song.samplerate / 512).astype(np.int32)

                # drop indices that fall past the end of the feature matrix
                song.major_note_index = np.delete(song.major_note_index, np.where(song.major_note_index >= song.feats.shape[2]))

                song.answer[song.major_note_index] = 1

                # milden() comes from music_processor; it is assumed to soften the
                # one-hot labels around each onset
                song.answer = milden(song.answer)
    
            # training: MSE between the predicted onset probabilities and the
            # softened labels
            optimizer = optim.SGD(self.parameters(), lr=0.02)
            criterion = nn.MSELoss()
            running_loss = 0
            val_loss = 0
    
            for i in range(epoch):
                for song in songs:
                    for X, y in self.train_data_builder(song.feats, song.answer, song.major_note_index, song.samplerate, soundlen, minibatch, split=0.2):
                        optimizer.zero_grad()
                        output = self(X.to(device), istraining=True, minibatch=minibatch)
                        target = y.to(device)
                        loss = criterion(output.squeeze(), target)
                        loss.backward()
                        optimizer.step()
                        running_loss += loss.item()
    
                with open(log, 'a') as f:
                    print("epoch: %d running_loss: %.10f" % (i + 1, running_loss), file=f)

                print("epoch: %d running_loss: %.10f" % (i + 1, running_loss))
                
                running_loss = 0    
    
                if val_song:
                    inference = torch.from_numpy(self.infer(val_song.feats, device, minibatch=512)).to(device)
                    target = torch.from_numpy(val_song.answer[:-soundlen]).float().to(device)
                    loss = criterion(inference.squeeze(), target)
                    val_loss = loss.item()
    
                    with open(log, 'a') as f:
                        print("val_loss: %.10f " % (val_loss), file=f)
    
            torch.save(self.state_dict(), save_place)
    
    
        def infer(self, feats, device, minibatch=1):
            """Return per-frame onset probabilities for feats as a 1-D numpy array."""

            with torch.no_grad():
                inference = None
                for x in tqdm(self.infer_data_builder(feats, minibatch=minibatch), total=feats.shape[2]//minibatch):
                    output = self(x.to(device), minibatch=x.shape[0])
                    if inference is not None:
                        inference = np.concatenate((inference, output.cpu().numpy().reshape(-1)))
                    else:
                        inference = output.cpu().numpy().reshape(-1)
                
                return np.array(inference).reshape(-1)
    
    
    if __name__ == '__main__':
    
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        net = convNet()
        net = net.to(device)
               
        with open('./data/pickles/train_data.pickle', mode='rb') as f:
            songs = pickle.load(f)
    
        minibatch = 128
        soundlen = 15
        epoch = 100
    
        
        net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/model.pth', log='./data/log/log.txt')
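
    # Inference sketch (hypothetical usage; assumes a song object with the same
    # .feats layout as the training data):
    #   net.load_state_dict(torch.load('./models/model.pth'))
    #   onset_probs = net.infer(song.feats, device, minibatch=512)
    #   # onset_probs[i] is the predicted onset probability at frame i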