cnn_train.py 594 o
from cnn.model import *
from cnn.music_processor import *
from glob import glob
# Train convNet on every pickled song under ./data/pickles and hand the
# output locations for the model weights and the log to net.train.
# Explicit imports: previously these names were only reachable through the
# `from cnn.* import *` star imports above.
import os
import pickle

import torch

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = convNet()
net = net.to(device)

# glob() returns paths in arbitrary, filesystem-dependent order; sorting makes
# the order in which songs are loaded (and fed to training) reproducible.
songplaces = sorted(glob('./data/pickles/*.pickle'))
if not songplaces:
    # Fail fast instead of starting a training run with no data.
    raise FileNotFoundError(
        "no training data found matching './data/pickles/*.pickle'")

# NOTE(review): pickle.load can execute arbitrary code. Only load pickles
# produced locally by this project, never files from an untrusted source.
songs = []
for songplace in songplaces:
    with open(songplace, mode='rb') as f:
        song = pickle.load(f)
        songs.append(song)

# Training hyperparameters passed straight through to net.train.
minibatch = 128  # presumably samples per minibatch -- confirm in cnn.model
soundlen = 15    # meaning/units defined by net.train -- TODO confirm
epoch = 100      # presumably the number of training epochs -- confirm

model_path = './models/model.pth'
log_path = './data/log/log.txt'
# Make sure the output directories exist up front, in case net.train does
# not create them, so a long run cannot fail only at save/log time.
for out_dir in (os.path.dirname(model_path), os.path.dirname(log_path)):
    os.makedirs(out_dir, exist_ok=True)

# val_song=None: no validation song is passed to training.
net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch,
          device=device, soundlen=soundlen, save_place=model_path,
          log=log_path)