Sélectionner une révision Git
infer.py 1,38 Kio
from model import *
from music_processor import *
from assUtils import AssWriter
import pickle
import numpy as np
from scipy.signal import argrelmax
from librosa.util import peak_pick
from librosa.onset import onset_detect
def segment(songfile):
    """Run the onset network on a song and return its detected onsets.

    NOTE(review): ``songfile`` is currently unused -- the features always
    come from the first entry of ``./data/pickles/train_data.pickle``. The
    parameter is kept so existing callers keep working; confirm whether the
    audio at ``songfile`` was meant to be processed instead.

    Args:
        songfile: Path of the song to segment (ignored for now, see note).

    Returns:
        The value returned by ``detection`` for the flattened network
        output and the song's ``samplerate``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = convNet()
    net = net.to(device)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load pickles produced by this project.
    with open('./data/pickles/train_data.pickle', mode='rb') as f:
        songs = pickle.load(f)

    # map_location=device covers both the CPU-only and the CUDA case in one
    # call, and also remaps checkpoints saved on a different GPU index.
    state_dict = torch.load('./models/model.pth', map_location=device)
    net.load_state_dict(state_dict)

    song = songs[0]
    inference = net.infer(song.feats, device, minibatch=4192)
    # Flatten to 1-D before peak picking.
    inference = np.reshape(inference, (-1))
    return detection(inference, song.samplerate)
def detection(inference, samplerate, hop_length=512):
    """Pick onset peaks from the network output and scale them to times.

    Args:
        inference: Per-frame network output; ``segment`` passes a flattened
            (1-D) array.
        samplerate: Sample rate used to convert sample offsets to time.
        hop_length: Factor converting a frame index into a sample offset.
            Defaults to 512, the value that used to be hard-coded here.

    Returns:
        The picked peak frame indices multiplied by
        ``hop_length / samplerate`` (seconds, assuming ``samplerate`` is in
        Hz and ``hop_length`` is in samples -- TODO confirm).
    """
    smoothed = smooth(inference, 5)
    # Original author's note (translated): the picked frame actually
    # corresponds to the sound at the 7th frame.
    peak_frames = peak_pick(
        smoothed,
        pre_max=1,
        post_max=2,
        pre_avg=4,
        post_avg=5,
        delta=0.05,
        wait=3,
    )
    return peak_frames * hop_length / samplerate
if __name__ == '__main__':
    # Script entry point: detect onsets for the song given on the command
    # line and write them as empty syllables to ./media/test.ass.
    # Imported here explicitly: the file otherwise only gets ``sys`` (if at
    # all) through its star imports.
    import sys

    if len(sys.argv) < 2:
        # Fail with a usage message instead of a bare IndexError.
        sys.exit('usage: %s <songfile>' % sys.argv[0])

    onsets = segment(sys.argv[1])
    # One [time, text] pair per onset; the text is left empty.
    syls = [[t, ''] for t in onsets]
    print(syls)

    writer = AssWriter()
    writer.openAss("./media/test.ass")
    try:
        writer.writeHeader()
        writer.writeSyls(syls)
    finally:
        # Always close the output, even if writing fails part-way.
        writer.closeAss()