"""Train convNet on every pickled song found under ./data/pickles.

Loads all ``*.pickle`` files matching ./data/pickles/*.pickle, builds the
network on the first CUDA device if available (CPU otherwise), and calls
``net.train`` with save_place='./models/model.pth' and
log='./data/log/log.txt'. Runs at import time, like the original script.
"""
# Project star imports come first; the explicit imports below are placed
# after them on purpose so that they take precedence over any same-named
# symbols the star imports might re-export (e.g. a `glob` module object).
from cnn.model import *
from cnn.music_processor import *
from glob import glob
import pickle
import torch

# Prefer the first CUDA GPU when present; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

net = convNet()
net = net.to(device)

# glob() returns matches in filesystem-dependent order; sort so the song list
# (and therefore training order) is reproducible across runs and machines.
songplaces = sorted(glob('./data/pickles/*.pickle'))
if not songplaces:
    # Fail fast instead of silently starting a training run with no data.
    raise FileNotFoundError(
        "no training data found matching './data/pickles/*.pickle'")

songs = []
for songplace in songplaces:
    # NOTE: pickle.load can execute arbitrary code. Only load pickles you
    # produced yourself; these are assumed to be local preprocessing output.
    with open(songplace, mode='rb') as f:
        song = pickle.load(f)
    songs.append(song)

minibatch = 128  # passed to net.train as `minibatch` (presumably batch size)
soundlen = 15    # passed to net.train as `soundlen`; meaning defined there — TODO confirm
epoch = 100      # passed to net.train as `epoch` (presumably number of epochs)

# convNet.train takes training keyword arguments, so it is a project-defined
# training routine (assumed to live in cnn.model). val_song=None presumably
# disables validation — confirm against convNet.train.
net.train(songs=songs, minibatch=minibatch, val_song=None, epoch=epoch,
          device=device, soundlen=soundlen,
          save_place='./models/model.pth', log='./data/log/log.txt')