# cnn_train.py: build a convNet and train it on the pickled songs in ./data/pickles.
from cnn.model import *
from cnn.music_processor import *
from glob import glob


# Import these explicitly: the script previously relied on `torch` and
# `pickle` leaking in through the `from cnn.* import *` lines above, which
# breaks silently if the cnn package ever changes what it exports.
import pickle

import torch


def _load_pickle(path):
    """Unpickle and return the object stored in the file at *path*.

    NOTE(review): pickle.load can execute arbitrary code while loading; only
    point this script at pickle files produced by this project.
    """
    with open(path, mode='rb') as f:
        return pickle.load(f)


# glob() yields paths in arbitrary, filesystem-dependent order. Sort them so
# the song list (and therefore the training run) is the same on every machine.
songplaces = sorted(glob('./data/pickles/*.pickle'))
if not songplaces:
    # Fail fast with a clear message instead of handing an empty dataset to
    # net.train and allocating a model for nothing.
    raise FileNotFoundError(
        "no training data found: './data/pickles/*.pickle' matched no files"
    )
songs = [_load_pickle(songplace) for songplace in songplaces]

# Train on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = convNet()
net = net.to(device)

# Hyperparameters forwarded to convNet.train. Their exact meaning is defined
# in cnn.model; TODO confirm (e.g. whether `soundlen` is in frames or seconds).
minibatch = 128
soundlen = 15
epoch = 100

# NOTE(review): this is convNet's own training entry point. If convNet is a
# torch.nn.Module, it overrides Module.train(mode); verify against cnn.model.
net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch, device=device, soundlen=soundlen, save_place='./models/model.pth', log='./data/log/log.txt')