diff --git a/README.md b/README.md
index 5adecf427a36b3e423589c27fe1fe8035f6502ca..5239f2f3ddba3a304e9d93170c8130eeeb04b4f6 100644
--- a/README.md
+++ b/README.md
@@ -69,25 +69,54 @@ $ deactivate
 Having a CUDA-capable GPU is optional, but can greatly reduce processing time in some situations.
 
 
+To use the custom phonetic mapping for Japanese Romaji, you currently need to update the g2p database manually (from within the venv):
+```bash
+$ cp g2p/mappings/langs/rji/* env/lib/python3.11/site-packages/g2p/mappings/langs/rji/
+
+# Then update the g2p database:
+$ g2p update
+```
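+
+To check that the custom mapping is picked up, you can optionally convert a test word with the g2p CLI (a quick sanity check; the exact phone output and CLI syntax may vary between g2p versions) :
+```bash
+$ g2p convert konnichiwa rji rji-eng-arpa
+```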
+
+
+
 # Use
 
 ## Autokara
 
-To execute AutoKara from scratch on a MKV video file :
+To use AutoKara, you need :
+ - A media file of the song (video, or pre-extracted vocals)
+ - An ASS file with the lyrics, split by syllable
+
+To execute AutoKara on an MKV video file and an ASS file containing the lyrics (the ASS file will be overwritten):
 ```bash
-$ python autokara.py video.mkv output.ass
+$ python autokara.py video.mkv lyrics.ass
 ```
 
-To execute AutoKara with existing syl splits and line timings :
+To output to a different file (and keep the original) :
 ```bash
-$ python autokara.py video.mkv output.ass --ref reference.ass
+$ python autokara.py video.mkv lyrics.ass -o output.ass
 ```
 
-To execute AutoKara on a (pre-extracted) WAV vocals file :
+To execute AutoKara on a pre-extracted vocals file (WAV, OGG, MP3, ...), pass the `--vocals` flag :
 ```bash
 $ python autokara.py vocals.wav output.ass --vocals
 ```
 
+To use a phonetic transcription optimized for a specific language, use `--lang` (or `-l`) :
+```bash
+$ python autokara.py vocals.wav output.ass --lang jp
+```
+Available languages are :
+```
+jp : Japanese Romaji (default)
+en : English
+```
+
+Full help for all options is available with :
+```bash
+$ python autokara.py -h
+```
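+
+For example, a typical end-to-end run combining these options (file names are placeholders; `--verbose` simply increases logging) :
+```bash
+$ python autokara.py song.mkv lyrics.ass -o timed.ass --lang jp --verbose
+```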
+
 ## Useful scripts
 
 To only extract .wav audio from a MKV file :
@@ -110,6 +139,12 @@ Batch preprocessing (vocals + ASS extraction) of all videos in a directory :
 $ ./preprocess_media.sh video_folder output_folder
 ```
 
+A visualization tool, mainly intended for debugging, is also available.
+It does the same job as autokara.py but, instead of writing to a file, plots a graph with onset times, the spectrogram, probability curves, etc.
+It does not work on video files, only on separated vocals audio files :
+```bash
+$ python plot_syls.py vocals.wav lyrics.ass
+```
 
 
 
diff --git a/autokara.py b/autokara.py
index 9cbf49de3b534b67cb99d13a2f7702ec4725a086..8fe6f2e6eafc35e54f1f1c6dab83e3f1bab72dce 100644
--- a/autokara.py
+++ b/autokara.py
@@ -4,20 +4,23 @@ import demucs.separate
 import subprocess
 import shlex
 from pathlib import Path
-from assUtils import AssWriter, getSyls
 
+from autosyl.assUtils import AssWriter, getSyls, getHeader
 from autosyl.segment import segment
 
 
 parser = argparse.ArgumentParser(description='AutoKara - Automatic karaoke timing tool')
 parser.add_argument("source_file", type=str, help="The video/audio file to time")
-parser.add_argument("ass_file", type=str, help="The ASS file in which to output the karaoke")
+parser.add_argument("ass_file", type=str, help="The ASS file with lyrics to time")
 parser.add_argument("--vocals", action="store_true", help="Treat the input as vocals file, i.e. do not perform vocals extraction")
-parser.add_argument("--ref", help="Use an ASS file as reference")
+parser.add_argument("-o", "--output", help="Write output to specified file. If absent, overwrite source file")
+parser.add_argument("-v","--verbose", action="store_true", help="Increased verbosity")
+parser.add_argument("-l","--lang", help="Select language to use (default is Japanese Romaji)")
 
 args = parser.parse_args()
 
 ass_file = args.ass_file
+verbose = args.verbose
 
 if not args.vocals :
     print("Extracting audio from video file...")
@@ -40,19 +43,26 @@ if not args.vocals :
 else:
     vocals_file = args.source_file
 
-if args.ref:
-    reference_syls = getSyls(args.ref)
-else:
-    reference_syls = None
+
 
 print("Identifying syl starts...")
-syls = segment(vocals_file, reference_syls=reference_syls)
+
+
+if verbose:
+    print("Retrieving syls from lyrics...")
+reference_syls, line_meta = getSyls(ass_file)
+
+if verbose:
+    print("Starting syl detection...")
+syls = segment(vocals_file, reference_syls=reference_syls, verbose=verbose, language=args.lang)
 print(syls)
+print(line_meta)
 
 print("Syls found, writing ASS file...")
+header = getHeader(ass_file)
 writer = AssWriter()
-writer.openAss(ass_file)
-writer.writeHeader()
-writer.writeSyls(syls)
+writer.openAss(args.output if args.output else ass_file)
+writer.writeHeader(header=header)
+writer.writeSyls(syls, line_meta)
 writer.closeAss()
 
diff --git a/autosyl/LyricsAlignment/LICENSE b/autosyl/LyricsAlignment/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..52ee0f284c95ab6f0b82c36a5e155845cb0a20df
--- /dev/null
+++ b/autosyl/LyricsAlignment/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Jiawen Huang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/autosyl/LyricsAlignment/checkpoints/checkpoint_BDR b/autosyl/LyricsAlignment/checkpoints/checkpoint_BDR
new file mode 100644
index 0000000000000000000000000000000000000000..d94b94d9bf90621d49d3be8feed1c0f3194611d0
Binary files /dev/null and b/autosyl/LyricsAlignment/checkpoints/checkpoint_BDR differ
diff --git a/autosyl/LyricsAlignment/checkpoints/checkpoint_Baseline b/autosyl/LyricsAlignment/checkpoints/checkpoint_Baseline
new file mode 100644
index 0000000000000000000000000000000000000000..f547d3e3b5968012e39411ce08552f989a4fa74f
Binary files /dev/null and b/autosyl/LyricsAlignment/checkpoints/checkpoint_Baseline differ
diff --git a/autosyl/LyricsAlignment/checkpoints/checkpoint_MTL b/autosyl/LyricsAlignment/checkpoints/checkpoint_MTL
new file mode 100644
index 0000000000000000000000000000000000000000..5fc7a6ec9ebae369229d215811451f3aa1d71b08
Binary files /dev/null and b/autosyl/LyricsAlignment/checkpoints/checkpoint_MTL differ
diff --git a/autosyl/LyricsAlignment/model.py b/autosyl/LyricsAlignment/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6fd66b10cb934ee1911b50d35b6bbcd13a9320f
--- /dev/null
+++ b/autosyl/LyricsAlignment/model.py
@@ -0,0 +1,236 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+import warnings
+
+from autosyl.LyricsAlignment.utils import notes_to_pc
+
+# the following FFT parameters are designed for a 22.05 kHz sampling rate
+sr = 22050
+n_fft = 512
+resolution = 256/22050*3
+
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore")
+    train_audio_transforms = nn.Sequential(
+        torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_mels=128, n_fft=n_fft),
+    )
+
+def data_processing(data):
+    spectrograms = []
+    phones = []
+    pcs = []
+    input_lengths = []
+    phone_lengths = []
+    for (waveform, _, _, phone, notes) in data:
+        waveform = torch.Tensor(waveform)
+        # convert to Mel
+        spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1) # time x n_mels
+        spectrograms.append(spec)
+
+        # get phoneme list (mapped to integers)
+        phone = torch.Tensor(phone)
+        phones.append(phone)
+
+        # get the pitch contour
+        # the number 3 here and below is due to the maxpooling along the frequency axis
+        pc = notes_to_pc(notes, resolution, spec.shape[0] // 3)
+        pcs.append(pc)
+
+        input_lengths.append(spec.shape[0]//3)
+        phone_lengths.append(len(phone))
+
+    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+    phones = nn.utils.rnn.pad_sequence(phones, batch_first=True)
+
+    return spectrograms, phones, input_lengths, phone_lengths, torch.LongTensor(pcs)
+
+class CNNLayerNorm(nn.Module):
+    '''Layer normalization built for cnns input'''
+
+    def __init__(self, n_feats):
+        super(CNNLayerNorm, self).__init__()
+        self.layer_norm = nn.LayerNorm(n_feats)
+
+    def forward(self, x):
+        # x (batch, channel, feature, time)
+        x = x.transpose(2, 3).contiguous()  # (batch, channel, time, feature)
+        x = self.layer_norm(x)
+        return x.transpose(2, 3).contiguous()  # (batch, channel, feature, time)
+
+
+class ResidualCNN(nn.Module):
+    '''Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
+        except with layer norm instead of batch norm
+    '''
+
+    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
+        super(ResidualCNN, self).__init__()
+
+        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
+        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel // 2)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.layer_norm1 = CNNLayerNorm(n_feats)
+        self.layer_norm2 = CNNLayerNorm(n_feats)
+
+    def forward(self, x):
+        residual = x  # (batch, channel, feature, time)
+        x = self.layer_norm1(x)
+        x = F.gelu(x)
+        x = self.dropout1(x)
+        x = self.cnn1(x)
+        x = self.layer_norm2(x)
+        x = F.gelu(x)
+        x = self.dropout2(x)
+        x = self.cnn2(x)
+        x += residual
+        return x  # (batch, channel, feature, time)
+
+
+class BidirectionalLSTM(nn.Module):
+
+    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
+        super(BidirectionalLSTM, self).__init__()
+
+        self.BiLSTM = nn.LSTM(
+            input_size=rnn_dim, hidden_size=hidden_size,
+            num_layers=1, batch_first=batch_first, bidirectional=True)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x, _ = self.BiLSTM(x)
+        x = self.dropout(x)
+        return x
+
+class AcousticModel(nn.Module):
+    '''
+        The acoustic model: baseline and MTL share the same class,
+        the only difference is the target dimension of the last fc layer
+    '''
+
+    def __init__(self, n_cnn_layers, rnn_dim, n_class, n_feats, stride=1, dropout=0.1):
+        super(AcousticModel, self).__init__()
+
+        self.n_class = n_class
+        if isinstance(n_class, int):
+            target_dim = n_class
+        else:
+            target_dim = n_class[0] * n_class[1]
+
+        self.cnn_layers = nn.Sequential(
+            nn.Conv2d(1, n_feats, 3, stride=stride, padding=3 // 2),
+            nn.ReLU()
+        )
+
+        self.rescnn_layers = nn.Sequential(*[
+            ResidualCNN(n_feats, n_feats, kernel=3, stride=1, dropout=dropout, n_feats=128)
+            for _ in range(n_cnn_layers)
+        ])
+
+        self.maxpooling = nn.MaxPool2d(kernel_size=(2, 3))
+        self.fully_connected = nn.Linear(n_feats * 64, rnn_dim)
+
+        self.bilstm = nn.Sequential(
+            BidirectionalLSTM(rnn_dim=rnn_dim, hidden_size=rnn_dim, dropout=dropout, batch_first=True),
+            BidirectionalLSTM(rnn_dim=rnn_dim * 2, hidden_size=rnn_dim, dropout=dropout, batch_first=False),
+            BidirectionalLSTM(rnn_dim=rnn_dim * 2, hidden_size=rnn_dim, dropout=dropout, batch_first=False)
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(rnn_dim * 2, target_dim)
+        )
+
+    def forward(self, x):
+        x = self.cnn_layers(x)
+        x = self.rescnn_layers(x)
+        x = self.maxpooling(x)
+
+        sizes = x.size()
+        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
+        x = x.transpose(1, 2)  # (batch, time, feature)
+        x = self.fully_connected(x)
+
+        x = self.bilstm(x)
+        x = self.classifier(x)
+
+        if isinstance(self.n_class, tuple):
+            x = x.view(sizes[0], sizes[3], self.n_class[0], self.n_class[1])
+
+        return x
+
+class MultiTaskLossWrapper(nn.Module):
+    def __init__(self):
+        super(MultiTaskLossWrapper, self).__init__()
+
+        self.criterion_lyrics = nn.CTCLoss(blank=40, zero_infinity=True)
+        self.criterion_melody = nn.CrossEntropyLoss()
+
+    def forward(self, mat3d, lyrics_gt, melody_gt):
+
+        n_batch, n_frame, n_ch, n_p = mat3d.shape # (batch, time, phone, pitch)
+
+        y_lyrics = torch.sum(mat3d, dim=3) # (batch, time, n_ch)
+        y_melody = torch.sum(mat3d, dim=2) # (batch, time, n_p)
+
+        y_lyrics = F.log_softmax(y_lyrics, dim=2)
+        y_lyrics = y_lyrics.transpose(0, 1) # (time, batch, n_ch) reshape for CTC
+        labels, input_lengths, label_lengths = lyrics_gt
+        loss_lyrics = self.criterion_lyrics(y_lyrics, labels, input_lengths, label_lengths)
+
+        y_melody = y_melody.transpose(1, 2)  # (batch, n_p, time)
+        loss_melody = self.criterion_melody(y_melody, melody_gt)
+
+        return loss_lyrics, loss_melody
+
+
+class BoundaryDetection(nn.Module):
+
+    def __init__(self, n_cnn_layers, rnn_dim, n_class, n_feats, stride=1, dropout=0.1):
+        super(BoundaryDetection, self).__init__()
+
+        self.n_class = n_class
+
+        # n residual cnn layers with filter size of 32
+        self.cnn_layers = nn.Sequential(
+            nn.Conv2d(1, n_feats, 3, stride=stride, padding=3 // 2),
+            nn.ReLU()
+        )
+
+        self.rescnn_layers = nn.Sequential(*[
+            ResidualCNN(n_feats, n_feats, kernel=3, stride=1, dropout=dropout, n_feats=128)
+            for _ in range(n_cnn_layers)
+        ])
+
+        self.maxpooling = nn.MaxPool2d(kernel_size=(2, 3))
+        self.fully_connected = nn.Linear(n_feats * 64, rnn_dim) # add a linear layer
+
+        self.bilstm_layers = nn.Sequential(
+            BidirectionalLSTM(rnn_dim=rnn_dim, hidden_size=rnn_dim, dropout=dropout, batch_first=True),
+            BidirectionalLSTM(rnn_dim=rnn_dim * 2, hidden_size=rnn_dim, dropout=dropout, batch_first=False),
+            BidirectionalLSTM(rnn_dim=rnn_dim * 2, hidden_size=rnn_dim, dropout=dropout, batch_first=False)
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(rnn_dim * 2, n_class)  # birnn returns rnn_dim*2
+        )
+
+    def forward(self, x):
+        x = self.cnn_layers(x)
+        x = self.rescnn_layers(x)
+        x = self.maxpooling(x)
+
+        sizes = x.size()
+        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
+        x = x.transpose(1, 2)  # (batch, time, feature)
+        x = self.fully_connected(x)
+
+        x = self.bilstm_layers(x)
+
+        x = self.classifier(x)
+        x = x.view(sizes[0], sizes[3], self.n_class)
+
+        x = torch.sigmoid(x)
+
+        return x
\ No newline at end of file
diff --git a/autosyl/LyricsAlignment/utils.py b/autosyl/LyricsAlignment/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec8aab853dbd5f016f5e5e9d55d020c7f93cdd7a
--- /dev/null
+++ b/autosyl/LyricsAlignment/utils.py
@@ -0,0 +1,502 @@
+import os
+import soundfile
+import torch
+import numpy as np
+import librosa
+import string
+import warnings
+import g2p_en
+from g2p import make_g2p
+
+
+
+phone_dict = ['AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY',
+             'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y',
+             'Z', 'ZH', ' ']
+phone2int = {phone_dict[i]: i for i in range(len(phone_dict))}
+
+
+class G2p_Wrapper():
+
+    def __init__(self, language="jp"):
+        if language == "en":
+            self.transducer = g2p_en.G2p()
+        else:                                                   # Only Japanese Romaji for now...
+            self.transducer = make_g2p('rji', 'rji-eng-arpa')
+
+        self.language = language
+
+    def __call__(self, word):
+        if self.language == "en":
+            return self.transducer(word)
+        else:
+            return self.transducer(word).output_string.split()
+
+#g2p = G2p_Wrapper(language="jp")
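+
+# Illustrative usage of G2p_Wrapper (not executed here; the exact phone sequences
+# depend on the installed g2p mappings and on the g2p_en version):
+#   G2p_Wrapper(language="jp")("konnichiwa")  ->  ARPABET-like phones, e.g. ['K', 'OW', 'N', ...]
+#   G2p_Wrapper(language="en")("hello")       ->  e.g. ['HH', 'AH0', 'L', 'OW1']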
+
+
+
+def my_collate(batch):
+    audio, targets, seqs = zip(*batch)
+    audio = np.array(audio)
+    targets = list(targets)
+    seqs = list(seqs)
+    return audio, targets, seqs
+
+def worker_init_fn(worker_id):
+    np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+def find_separated_vocal(fileid):
+
+    pass
+
+def load(path, sr=22050, mono=True, offset=0., duration=None):
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        y, curr_sr = librosa.load(path, sr=sr, mono=mono, res_type='kaiser_fast', offset=offset, duration=duration)
+
+    if len(y.shape) == 1:
+        y = y[np.newaxis, :] # (channel, sample)
+
+    return y, curr_sr
+
+def load_lyrics(lyrics_file):
+    from string import ascii_lowercase
+    d = {ascii_lowercase[i]: i for i in range(26)}
+    d["'"] = 26
+    d[" "] = 27
+    d["~"] = 28
+
+    # process raw
+    with open(lyrics_file + '.raw.txt', 'r') as f:
+        raw_lines = f.read().splitlines()
+    raw_lines = ["".join([c for c in line.lower() if c in d.keys()]).strip() for line in raw_lines]
+    raw_lines = [" ".join(line.split()) for line in raw_lines if len(line) > 0]
+    # concat
+    full_lyrics = " ".join(raw_lines)
+
+    # split to words
+    with open(lyrics_file + '.words.txt', 'r') as f:
+        words_lines = f.read().splitlines()
+    idx = []
+    last_end = 0
+    for i in range(len(words_lines)):
+        word = words_lines[i]
+        try:
+            assert (word[0] in ascii_lowercase)
+        except:
+            # print(word)
+            pass
+        new_word = "".join([c for c in word.lower() if c in d.keys()])
+        offset = full_lyrics[last_end:].find(new_word)
+        assert (offset >= 0)
+        assert (new_word == full_lyrics[last_end + offset:last_end + offset + len(new_word)])
+        idx.append([last_end + offset, last_end + offset + len(new_word)])
+        last_end += offset + len(new_word)
+
+    # beginning of a line
+    idx_line = []
+    last_end = 0
+    for i in range(len(raw_lines)):
+        line = raw_lines[i]
+        offset = full_lyrics[last_end:].find(line)
+        assert (offset >= 0)
+        assert (line == full_lyrics[last_end + offset:last_end + offset + len(line)])
+        idx_line.append([last_end + offset, last_end + offset + len(line)])
+        last_end += offset + len(line)
+
+    return full_lyrics, words_lines, idx, idx_line, raw_lines
+
+def write_wav(path, audio, sr):
+    soundfile.write(path, audio.T, sr, "PCM_16")
+
+def gen_phone_gt(words, raw_lines, language="jp"):
+
+    print(f"Translating lyrics to phonemes, language chosen : {language:s}")
+    g2p = G2p_Wrapper(language=language)
+
+
+    # helper function
+    def getsubidx(x, y):  # find y in x
+        l1, l2 = len(x), len(y)
+        for i in range(l1 - l2 + 1):
+            if x[i:i + l2] == y:
+                return i
+    words_p = []
+    lyrics_p = []
+    for word in words:
+        out = g2p(word)
+        out = [phone if phone[-1] not in string.digits else phone[:-1] for phone in out]
+        words_p.append(out)
+        if len(lyrics_p) > 0:
+            lyrics_p.append(' ')
+        lyrics_p += out
+
+    len_words_p = [len(phones) for phones in words_p]
+    idx_in_full_p = []
+    s1 = 0
+    s2 = s1
+    for l in len_words_p:
+        s2 = s1 + l
+        idx_in_full_p.append([s1, s2])
+        s1 = s2 + 1
+
+    # beginning of a line
+    idx_line_p = []
+    last_end = 0
+    for i in range(len(raw_lines)):
+        line = []
+        line_phone = [g2p(word) for word in raw_lines[i].split()]
+        for l in line_phone:
+            line += l + [' ']
+        line = line[:-1]
+        line = [phone if phone[-1] not in string.digits else phone[:-1] for phone in line]
+        offset = getsubidx(lyrics_p[last_end:], line)
+        assert (offset >= 0)
+        assert (line == lyrics_p[last_end + offset:last_end + offset + len(line)])
+        idx_line_p.append([last_end + offset, last_end + offset + len(line)])
+        last_end += offset + len(line)
+
+    return lyrics_p, words_p, idx_in_full_p, idx_line_p
+
+class DataParallel(torch.nn.DataParallel):
+    def __init__(self, module, device_ids=None, output_device=None, dim=0):
+        super(DataParallel, self).__init__(module, device_ids, output_device, dim)
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.module, name)
+
+def save_model(model, optimizer, state, path):
+    if isinstance(model, torch.nn.DataParallel):
+        model = model.module  # save state dict of wrapped module
+    if len(os.path.dirname(path)) > 0 and not os.path.exists(os.path.dirname(path)):
+        os.makedirs(os.path.dirname(path))
+    torch.save({
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'state': state,
+    }, path)
+
+def load_model(model, path, cuda):
+    if isinstance(model, torch.nn.DataParallel):
+        model = model.module  # load state dict of wrapped module
+    if cuda:
+        checkpoint = torch.load(path)
+    else:
+        checkpoint = torch.load(path, map_location='cpu')
+    model.load_state_dict(checkpoint['model_state_dict'])
+
+    if 'state' in checkpoint:
+        state = checkpoint['state']
+    else:
+        state = {"step": 0,
+                 "worse_epochs": 0,
+                 "epochs": checkpoint['epoch'],
+                 "best_loss": np.Inf}
+
+    return state
+
+def seed_torch(seed=0):
+    # random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+def move_data_to_device(x, device):
+    if 'float' in str(x.dtype):
+        x = torch.Tensor(x)
+    elif 'int' in str(x.dtype):
+        x = torch.LongTensor(x)
+    else:
+        return x
+
+    return x.to(device)
+
+def alignment(song_pred, lyrics, idx):
+    audio_length, num_class = song_pred.shape
+    lyrics_int = phone2seq(lyrics)
+    lyrics_length = len(lyrics_int)
+
+    s = np.zeros((audio_length, 2*lyrics_length+1)) - np.Inf
+    opt = np.zeros((audio_length, 2*lyrics_length+1))
+
+    blank = 40
+
+    # init
+    s[0][0] = song_pred[0][blank]
+    # insert eps
+    for i in np.arange(1, audio_length):
+        s[i][0] = s[i-1][0] + song_pred[i][blank]
+
+    for j in np.arange(lyrics_length):
+        if j == 0:
+            s[j+1][2*j+1] = s[j][2*j] + song_pred[j+1][lyrics_int[j]]
+            opt[j+1][2*j+1] = 1  # 45 degree
+        else:
+            s[j+1][2*j+1] = s[j][2*j-1] + song_pred[j+1][lyrics_int[j]]
+            opt[j+1][2*j+1] = 2 # 28 degree
+
+        s[j+2][2*j+2] = s[j+1][2*j+1] + song_pred[j+2][blank]
+        opt[j+2][2*j+2] = 1  # 45 degree
+
+
+    for audio_pos in np.arange(2, audio_length):
+
+        for ch_pos in np.arange(1, 2*lyrics_length+1):
+
+            if ch_pos % 2 == 1 and (ch_pos+1)/2 >= audio_pos:
+                break
+            if ch_pos % 2 == 0 and ch_pos/2 + 1 >= audio_pos:
+                break
+
+            if ch_pos % 2 == 1: # ch
+                ch_idx = int((ch_pos-1)/2)
+                # cur ch -> ch
+                a = s[audio_pos-1][ch_pos] + song_pred[audio_pos][lyrics_int[ch_idx]]
+                # last ch -> ch
+                b = s[audio_pos-1][ch_pos-2] + song_pred[audio_pos][lyrics_int[ch_idx]]
+                # eps -> ch
+                c = s[audio_pos-1][ch_pos-1] + song_pred[audio_pos][lyrics_int[ch_idx]]
+                if a > b and a > c:
+                    s[audio_pos][ch_pos] = a
+                    opt[audio_pos][ch_pos] = 0
+                elif b >= a and b >= c:
+                    s[audio_pos][ch_pos] = b
+                    opt[audio_pos][ch_pos] = 2
+                else:
+                    s[audio_pos][ch_pos] = c
+                    opt[audio_pos][ch_pos] = 1
+
+            if ch_pos % 2 == 0: # eps
+                # cur ch -> ch
+                a = s[audio_pos-1][ch_pos] + song_pred[audio_pos][blank]
+                # eps -> ch
+                c = s[audio_pos-1][ch_pos-1] + song_pred[audio_pos][blank]
+                if a > c:
+                    s[audio_pos][ch_pos] = a
+                    opt[audio_pos][ch_pos] = 0
+                else:
+                    s[audio_pos][ch_pos] = c
+                    opt[audio_pos][ch_pos] = 1
+
+    score = s[audio_length-1][2*lyrics_length]
+
+    # retrieve optimal path
+    path = []
+    x = audio_length-1
+    y = 2*lyrics_length
+    path.append([x, y])
+    while x > 0 or y > 0:
+        if opt[x][y] == 1:
+            x -= 1
+            y -= 1
+        elif opt[x][y] == 2:
+            x -= 1
+            y -= 2
+        else:
+            x -= 1
+        path.append([x, y])
+
+    path = list(reversed(path))
+    word_align = []
+    path_i = 0
+
+    word_i = 0
+    while word_i < len(idx):
+        # e.g. "happy day"
+        # find the first time "h" appears
+        if path[path_i][1] == 2*idx[word_i][0]+1:
+            st = path[path_i][0]
+            # find the first time " " appears after "h"
+            while  path_i < len(path)-1 and (path[path_i][1] != 2*idx[word_i][1]+1):
+                path_i += 1
+            ed = path[path_i][0]
+            # append
+            word_align.append([st, ed])
+            # move to next word
+            word_i += 1
+        else:
+            # move to next audio frame
+            path_i += 1
+
+    return word_align, score
+
+def alignment_bdr(song_pred, lyrics, idx, bdr_pred, line_start):
+    audio_length, num_class = song_pred.shape
+    lyrics_int = phone2seq(lyrics)
+    lyrics_length = len(lyrics_int)
+
+    s = np.zeros((audio_length, 2*lyrics_length+1)) - np.Inf
+    opt = np.zeros((audio_length, 2*lyrics_length+1))
+
+    blank = 40
+
+    # init
+    s[0][0] = song_pred[0][blank]
+    # insert eps
+    for i in np.arange(1, audio_length):
+        s[i][0] = s[i-1][0] + song_pred[i][blank]
+
+    for j in np.arange(lyrics_length):
+        if j == 0:
+            s[j+1][2*j+1] = s[j][2*j] + song_pred[j+1][lyrics_int[j]]
+            opt[j+1][2*j+1] = 1  # 45 degree
+        else:
+            s[j+1][2*j+1] = s[j][2*j-1] + song_pred[j+1][lyrics_int[j]]
+            opt[j+1][2*j+1] = 2 # 28 degree
+        if j in line_start:
+            s[j + 1][2 * j + 1] += bdr_pred[j+1]
+
+        s[j+2][2*j+2] = s[j+1][2*j+1] + song_pred[j+2][blank]
+        opt[j+2][2*j+2] = 1  # 45 degree
+
+    for audio_pos in np.arange(2, audio_length):
+
+        for ch_pos in np.arange(1, 2*lyrics_length+1):
+
+            if ch_pos % 2 == 1 and (ch_pos+1)/2 >= audio_pos:
+                break
+            if ch_pos % 2 == 0 and ch_pos/2 + 1 >= audio_pos:
+                break
+
+            if ch_pos % 2 == 1: # ch
+                ch_idx = int((ch_pos-1)/2)
+                # cur ch -> ch
+                a = s[audio_pos-1][ch_pos] + song_pred[audio_pos][lyrics_int[ch_idx]]
+                # last ch -> ch
+                b = s[audio_pos-1][ch_pos-2] + song_pred[audio_pos][lyrics_int[ch_idx]]
+                # eps -> ch
+                c = s[audio_pos-1][ch_pos-1] + song_pred[audio_pos][lyrics_int[ch_idx]]
+                if a > b and a > c:
+                    s[audio_pos][ch_pos] = a
+                    opt[audio_pos][ch_pos] = 0
+                elif b >= a and b >= c:
+                    s[audio_pos][ch_pos] = b
+                    opt[audio_pos][ch_pos] = 2
+                else:
+                    s[audio_pos][ch_pos] = c
+                    opt[audio_pos][ch_pos] = 1
+
+                if ch_idx in line_start:
+                    s[audio_pos][ch_pos] += bdr_pred[audio_pos]
+
+            if ch_pos % 2 == 0: # eps
+                # cur ch -> ch
+                a = s[audio_pos-1][ch_pos] + song_pred[audio_pos][blank]
+                # eps -> ch
+                c = s[audio_pos-1][ch_pos-1] + song_pred[audio_pos][blank]
+                if a > c:
+                    s[audio_pos][ch_pos] = a
+                    opt[audio_pos][ch_pos] = 0
+                else:
+                    s[audio_pos][ch_pos] = c
+                    opt[audio_pos][ch_pos] = 1
+
+    score = s[audio_length-1][2*lyrics_length]
+
+    # retrieve optimal path
+    path = []
+    x = audio_length-1
+    y = 2*lyrics_length
+    path.append([x, y])
+    while x > 0 or y > 0:
+        if opt[x][y] == 1:
+            x -= 1
+            y -= 1
+        elif opt[x][y] == 2:
+            x -= 1
+            y -= 2
+        else:
+            x -= 1
+        path.append([x, y])
+
+    path = list(reversed(path))
+    word_align = []
+    path_i = 0
+
+    word_i = 0
+    while word_i < len(idx):
+        # e.g. "happy day"
+        # find the first time "h" appears
+        if path[path_i][1] == 2*idx[word_i][0]+1:
+            st = path[path_i][0]
+            # find the first time " " appears after "h"
+            while  path_i < len(path)-1 and (path[path_i][1] != 2*idx[word_i][1]+1):
+                path_i += 1
+            ed = path[path_i][0]
+            # append
+            word_align.append([st, ed])
+            # move to next word
+            word_i += 1
+        else:
+            # move to next audio frame
+            path_i += 1
+
+    return word_align, score
+
+def phone2seq(text):
+    seq = []
+    for c in text:
+        if c in phone_dict:
+            idx = phone2int[c]
+        else:
+            # print(c) # unknown
+            idx = 40
+        seq.append(idx)
+    return np.array(seq)
+
+def ToolFreq2Midi(fInHz, fA4InHz=440):
+    '''
+    source: https://www.audiocontentanalysis.org/code/helper-functions/frequency-to-midi-pitch-conversion-2/
+    '''
+    def convert_freq2midi_scalar(f, fA4InHz):
+
+        if f <= 0:
+            return 0
+        else:
+            return (69 + 12 * np.log2(f / fA4InHz))
+
+    fInHz = np.asarray(fInHz)
+    if fInHz.ndim == 0:
+        return convert_freq2midi_scalar(fInHz, fA4InHz)
+
+    midi = np.zeros(fInHz.shape)
+    for k, f in enumerate(fInHz):
+        midi[k] = convert_freq2midi_scalar(f, fA4InHz)
+
+    return (midi)
+
+def notes_to_pc(notes, resolution, total_length):
+
+    pc = np.full(shape=(total_length,), fill_value=46, dtype=np.short)
+
+    for i in np.arange(len(notes[0])):
+        pitch = notes[0][i]
+        if pitch == -100:
+            pc[0:total_length] = pitch
+        else:
+            times = np.floor(notes[1][i] / resolution)
+            st = int(np.max([0, times[0]]))
+            ed = int(np.min([total_length, times[1]]))
+            pc[st:ed] = pitch
+
+    return pc
+
+def voc_to_contour(times, resolution, total_length, smoothing=False):
+
+    contour = np.full(shape=(total_length,), fill_value=0, dtype=np.short)
+
+    for i in np.arange(len(times)):
+        time = np.floor(times[i] / resolution)
+        st = int(np.max([0, time[0]]))
+        ed = int(np.min([total_length, time[1]]))
+        contour[st:ed] = 1
+
+        # TODO: add smoothing option
+        if smoothing:
+            pass
+
+    return contour
\ No newline at end of file
diff --git a/autosyl/LyricsAlignment/wrapper.py b/autosyl/LyricsAlignment/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..825ec7c7794955ba01769c713137e7cb28206c02
--- /dev/null
+++ b/autosyl/LyricsAlignment/wrapper.py
@@ -0,0 +1,179 @@
+import warnings, librosa
+import numpy as np
+from time import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import autosyl.LyricsAlignment.utils as utils
+from autosyl.LyricsAlignment.model import train_audio_transforms, AcousticModel, BoundaryDetection
+
+np.random.seed(7)
+
+def preprocess_from_file(audio_file, lyrics_file, word_file=None, language="jp"):
+    y, sr = preprocess_audio(audio_file)
+
+    words, lyrics_p, idx_word_p, idx_line_p = preprocess_lyrics(lyrics_file, word_file, language=language)
+
+    return y, words, lyrics_p, idx_word_p, idx_line_p
+
+def align(audio, words, lyrics_p, idx_word_p, idx_line_p, method="Baseline", cuda=True, checkpoint_folder="."):
+
+    # start timer
+    t = time()
+
+    # constants
+    resolution = 256 / 22050 * 3
+    alpha = 0.8
+
+    # decode method
+    if "BDR" in method:
+        model_type = method[:-4]
+        bdr_flag = True
+    else:
+        model_type = method
+        bdr_flag = False
+    print("Model: {} BDR?: {}".format(model_type, bdr_flag))
+
+    # prepare acoustic model params
+    if model_type == "Baseline":
+        n_class = 41
+    elif model_type == "MTL":
+        n_class = (41, 47)
+    else:
+        raise ValueError("Invalid model type.")
+
+    hparams = {
+        "n_cnn_layers": 1,
+        "n_rnn_layers": 3,
+        "rnn_dim": 256,
+        "n_class": n_class,
+        "n_feats": 32,
+        "stride": 1,
+        "dropout": 0.1
+    }
+
+    device = 'cuda' if (cuda and torch.cuda.is_available()) else 'cpu'
+
+    ac_model = AcousticModel(
+        hparams['n_cnn_layers'], hparams['rnn_dim'], hparams['n_class'], \
+        hparams['n_feats'], hparams['stride'], hparams['dropout']
+    ).to(device)
+
+    print("Loading acoustic model from checkpoint...")
+    state = utils.load_model(ac_model, "{}/checkpoint_{}".format(checkpoint_folder, model_type), cuda=(device=="gpu"))
+    ac_model.eval()
+
+    print("Computing phoneme posteriorgram...")
+
+    # reshape input, prepare mel
+    x = audio.reshape(1, 1, -1)
+    x = utils.move_data_to_device(x, device)
+    x = x.squeeze(0)
+    x = x.squeeze(1)
+    x = train_audio_transforms.to(device)(x)
+    x = nn.utils.rnn.pad_sequence(x, batch_first=True).unsqueeze(1)
+
+    # predict
+    all_outputs = ac_model(x)
+    if model_type == "MTL":
+        all_outputs = torch.sum(all_outputs, dim=3)
+
+    all_outputs = F.log_softmax(all_outputs, dim=2)
+
+    batch_num, output_length, num_classes = all_outputs.shape
+    song_pred = all_outputs.data.cpu().numpy().reshape(-1, num_classes)  # total_length, num_classes
+    total_length = int(audio.shape[1] / 22050 // resolution)
+    song_pred = song_pred[:total_length, :]
+
+    # smoothing
+    P_noise = np.random.uniform(low=1e-11, high=1e-10, size=song_pred.shape)
+    song_pred = np.log(np.exp(song_pred) + P_noise)
+
+    if bdr_flag:
+        # boundary model: fixed
+        bdr_hparams = {
+            "n_cnn_layers": 1,
+            "rnn_dim": 32,  # a smaller rnn dim than acoustic model
+            "n_class": 1,  # binary classification
+            "n_feats": 32,
+            "stride": 1,
+            "dropout": 0.1,
+        }
+
+        bdr_model = BoundaryDetection(
+            bdr_hparams['n_cnn_layers'], bdr_hparams['rnn_dim'], bdr_hparams['n_class'],
+            bdr_hparams['n_feats'], bdr_hparams['stride'], bdr_hparams['dropout']
+        ).to(device)
+        print("Loading BDR model from checkpoint...")
+        state = utils.load_model(bdr_model, "{}/checkpoint_BDR".format(checkpoint_folder), cuda=(device == "gpu"))
+        bdr_model.eval()
+
+        print("Computing boundary probability curve...")
+        # get boundary prob curve
+        bdr_outputs = bdr_model(x).data.cpu().numpy().reshape(-1)
+        # apply log
+        bdr_outputs = np.log(bdr_outputs) * alpha
+
+        line_start = [d[0] for d in idx_line_p]
+
+        # start alignment
+        print("Aligning...It might take a few minutes...")
+        word_align, score = utils.alignment_bdr(song_pred, lyrics_p, idx_word_p, bdr_outputs, line_start)
+    else:
+        # start alignment
+        print("Aligning...It might take a few minutes...")
+        word_align, score = utils.alignment(song_pred, lyrics_p, idx_word_p)
+
+    t = time() - t
+    print("Alignment Score:\t{}\tTime:\t{}".format(score, t))
+
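+    # note: 25600/22050*3 is 100x the model's frame resolution (256/22050*3 s per frame),
+    # so the word boundaries returned below are expressed in centiseconds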
+    resolution = 25600 / 22050 * 3
+    word_align = [[round(word[0] * resolution), round(word[1] * resolution)] for word in word_align]
+
+    return word_align, words
+
+def preprocess_audio(audio_file, sr=22050):
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        y, curr_sr = librosa.load(audio_file, sr=sr, mono=True, res_type='kaiser_fast')
+
+    if len(y.shape) == 1:
+        y = y[np.newaxis, :] # (channel, sample)
+
+    return y, curr_sr
+
+def preprocess_lyrics(lyrics_lines, word_file=None, language="jp"):
+    from string import ascii_lowercase
+    d = {ascii_lowercase[i]: i for i in range(26)}
+    d["'"] = 26
+    d[" "] = 27
+    d["~"] = 28
+
+    # process raw
+    #with open(lyrics_file, 'r') as f:
+    #    raw_lines = f.read().splitlines()
+    raw_lines = lyrics_lines
+
+    raw_lines = ["".join([c for c in line.lower() if c in d.keys()]).strip() for line in raw_lines]
+    raw_lines = [" ".join(line.split()) for line in raw_lines if len(line) > 0]
+    # concat
+    full_lyrics = " ".join(raw_lines)
+
+    if word_file:
+        with open(word_file) as f:
+            words_lines = f.read().splitlines()
+    else:
+        words_lines = full_lyrics.split()
+
+    lyrics_p, words_p, idx_word_p, idx_line_p = utils.gen_phone_gt(words_lines, raw_lines, language=language)
+
+    return words_lines, lyrics_p, idx_word_p, idx_line_p
+
+def write_csv(pred_file, word_align, words):
+    resolution = 256 / 22050 * 3
+
+    with open(pred_file, 'w') as f:
+        for j in range(len(word_align)):
+            word_time = word_align[j]
+            f.write("{},{},{}\n".format(word_time[0] * resolution, word_time[1] * resolution, words[j]))
diff --git a/assUtils.py b/autosyl/assUtils.py
similarity index 67%
rename from assUtils.py
rename to autosyl/assUtils.py
index 7b5fd284d7923cd7537cf24aad9062dc6f34df30..cefbfa17909b1e14a7b5424a384ba942be9b9c90 100644
--- a/assUtils.py
+++ b/autosyl/assUtils.py
@@ -25,21 +25,44 @@ def dateToTime(date):
 
 def getSyls(ass_file):
     SYLS = []
+    META = []
     with open(ass_file, 'r') as f:
         CONTENT = f.read()
-        LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n");
+        LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),([^,]*),([^,]*),(\d+),(\d+),(\d+),karaoke,(.*)\n");
         RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
         for line in LINES_KARA.findall(CONTENT):
             syl_line = []
             lastTime = dateToTime(line[0])
-            for couple in RGX_TAGS.findall(line[2]):
-                syl_line.append([lastTime, couple[1], int(couple[0])])
+            for couple in RGX_TAGS.findall(line[7]):
+                if(couple[1] != '' and not couple[1].isspace()):
+                    syl_line.append([lastTime, couple[1], int(couple[0])])
                 lastTime += int(couple[0])
             syl_line.append([lastTime, '', 0])
             SYLS.append(syl_line)
-    return SYLS
-
-
+            line_meta = {}
+            line_meta['stylename'] = line[2]
+            line_meta['actor'] = line[3]
+            line_meta['margin_l'] = int(line[4])
+            line_meta['margin_r'] = int(line[5])
+            line_meta['margin_v'] = int(line[6])
+            META.append(line_meta)
+    return SYLS, META
+
+
+def getHeader(ass_file):
+    HEADER = ""
+    with open(ass_file, 'r') as f:
+        events_section = False
+        for line in f.readlines():
+            if not events_section:
+                HEADER += line
+                if re.match(r"^\[Events\].*", line):
+                    events_section = True
+            else:
+                event_regex = r"(?:Format:.*)|(?:^Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),[^,]*,((?!karaoke).)*,((?!karaoke).)*\n$)"
+                if re.match(event_regex, line):
+                    HEADER += line
+    return HEADER
 
 
 
@@ -54,8 +77,9 @@ class AssWriter:
     def closeAss(self):
         self.file.close()
     
-    def writeHeader(self):
-        header = '''[Script Info]
+    def writeHeader(self, header=None):
+        if not header:
+            header = '''[Script Info]
 Title: Default Aegisub file
 ScriptType: v4.00+
 WrapStyle: 0
@@ -78,7 +102,7 @@ Comment: 0,0:00:03.68,0:00:05.68,Default,,0,0,0,code syl all,bord = line.stylere
 Comment: 0,0:00:05.68,0:00:05.68,Default,,0,0,0,,
 Comment: 0,0:00:05.68,0:00:05.68,Default,,0,0,0,,
 Comment: 10,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("start2syl",line_in_retime,syl_shift)!{\\fad(!fade_in_dur!,0)\\bord!bord!\shad!shadow!\\blur!blur!\\3c!line.styleref.color3!\c!line.styleref.color3!\pos($x,$y)}
-Comment: 20,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("start2syl",line_in_retime,syl_shift)!{\fad(200,0)\\bord0\shad0\c!line.styleref.color2!\pos($x,$y)}
+Comment: 20,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("start2syl",line_in_retime,syl_shift)!{\\fad(200,0)\\bord0\shad0\c!line.styleref.color2!\pos($x,$y)}
 Comment: 10,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("syl",syl_shift,syl_shift-$dur/syl_in_speed)!{\\3c!line.styleref.color3!\c!line.styleref.color3!\\bord!bord!\shad!shadow!\\blur!blur!\move($x,$y,$x,!$y-($height/8)!,0,!($dur > syl_in_speed * 400 and 400 or $dur/syl_in_speed)!)}
 Comment: 40,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("syl",syl_shift,syl_shift-$dur/syl_in_speed)!{\\bord0\shad0\c!line.styleref.color2!\\t(0,!($dur > syl_in_speed * 400 and 400 or $dur/syl_in_speed)!,\c!line.styleref.color1!)\move($x,$y,$x,!$y-($height/8)!,0,!($dur > syl_in_speed * 400 and 400 or $dur/syl_in_speed)!)}
 Comment: 10,0:00:05.68,0:00:10.43,Default,,0,0,0,template char noblank all,!retime("postsyl",syl_shift-$dur/syl_in_speed,($lstart + $send + ($dur/12)) < ($lend + 300) and $dur/12 or $lend - $lstart - $send + 300)!{\\3c!line.styleref.color3!\c!line.styleref.color3!\\bord!bord!\shad!shadow!\\blur!blur!\move($x,!$y-($height/8)!,$x,!$y!,0,!$dur/syl_out_speed!)!($lstart + $end + $dur/12) < ($lend + 300) and "" or "\\fad(0,150)"!}
@@ -90,16 +114,18 @@ Comment: 0,0:00:05.68,0:00:05.68,Default,,0,0,0,,
 '''
         self.file.write(header)
 
-    def writeSyls(self, syl_timings):
-        bottom = False
-        for syl_line in syl_timings:
-            start_time = timeToDate(syl_line[0][0])
-            end_time = timeToDate(syl_line[-1][0])
-            v_margin = (150 if bottom else 50)
-            line = f'Dialogue: 0,{start_time},{end_time},Default,,0,0,{v_margin:d},,'
-            for i in range(len(syl_line) - 1):
-                syl_dur = round((syl_line[i+1][0] - syl_line[i][0]) * 100)
-                line += f'{{\k{syl_dur:d}}}{syl_line[i][1]:s}'
+    def writeSyls(self, syl_timings, line_meta):
+        for line_index in range(len(syl_timings)):
+            start_time = timeToDate(syl_timings[line_index][0][0])
+            end_time = timeToDate(syl_timings[line_index][-1][0])
+            margin_v = line_meta[line_index]['margin_v']
+            margin_l = line_meta[line_index]['margin_l']
+            margin_r = line_meta[line_index]['margin_r']
+            style = line_meta[line_index]['stylename']
+            actor = line_meta[line_index]['actor']
+            line = f'Dialogue: 0,{start_time},{end_time},{style},{actor},{margin_l:d},{margin_r:d},{margin_v:d},,'
+            for i in range(len(syl_timings[line_index]) - 1):
+                syl_dur = round((syl_timings[line_index][i+1][0] - syl_timings[line_index][i][0]) * 100)
+                line += f'{{\\k{syl_dur:d}}}{syl_timings[line_index][i][1]:s}'
             line += '\n'
-            self.file.write(line)
-            bottom = not bottom
\ No newline at end of file
+            self.file.write(line)
\ No newline at end of file
diff --git a/autosyl/segment.py b/autosyl/segment.py
index 5f4d785b7767ce5a92b93f048c3238717e7be823..adacae081bfd39c17847a526334c05587bd11ae9 100644
--- a/autosyl/segment.py
+++ b/autosyl/segment.py
@@ -6,91 +6,53 @@ import matplotlib.pyplot as plt
 import scipy.signal as sg
 import parselmouth
 
+from autosyl.assUtils import getSyls, timeToDate, dateToTime 
+from autosyl.LyricsAlignment.wrapper import align, preprocess_from_file
 
-def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500):
 
-    delay = -4
-    backtrack = False
-
-    cnn = madmom.features.onsets.CNNOnsetProcessor()
-    spectral = madmom.features.onsets.SpectralOnsetProcessor('modified_kullback_leibler')
-
-    sig = madmom.audio.signal.Signal(songfile, num_channels=1)
-    parsel = parselmouth.Sound(sig)
-
-    spec = madmom.audio.spectrogram.Spectrogram(sig)
-    filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24)
-    log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1)
-
-    magnitude = np.max(log_spec[:,:100], axis=1)
 
-    cnn_function = cnn(sig)
-    spectral_function = spectral(sig)
-    spectral_function = spectral_function/(spectral_function.max())
-    
-    #activation_function = 0.5*cnn_function + 0.5*spectral_function
-    activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function)
-    #activation_function = np.where(spectral_function > 0.14, cnn_function, 0)
-    #onsets = proc(activation_function)
-    
-    if reference_syls:
-        activation_threshold = 0.1
-    else:
-        activation_threshold = 0.2
+def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500, verbose=False, language="jp"):
 
+    delay = -4
+    backtrack = False
 
-    activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
-    cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=activation_threshold, smooth=0)
-    #onsets = np.array([o for o in onsets if cnn_smoothed[o] > activation_threshold])
+    print(reference_syls)
 
-    pitch = parsel.to_pitch()
-    pitch_values = pitch.selected_array['frequency']
+    audio_file = songfile                      # pre-computed source-separated vocals; These models do not work with mixture input.
+    word_file = None                           # example: jamendolyrics/lyrics/*.words.txt"; Set to None if you don't have it
+    method = "MTL_BDR"                             # "Baseline", "MTL", "Baseline_BDR", "MTL_BDR"
+    cuda=False                                 # set True if you have access to a GPU
+    checkpoint_folder = "./autosyl/LyricsAlignment/checkpoints"
+    language = language
 
-    pad_before = round(pitch.xs()[0]*100)
-    pad_after = len(magnitude) - len(pitch_values) - pad_before
 
-    pitch_values = np.pad(pitch_values, (pad_before, pad_after), 'constant', constant_values=(0,0))
+    lyrics_lines = [" ".join([syl[1] for syl in line]) for line in reference_syls]
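+    # each line of reference_syls becomes a space-separated string of its syllables
+    # (e.g. [[t0, "ko", d0], [t1, "n", d1], ...] -> "ko n ..."), so the aligner
+    # treats every syllable as a separate word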
+    #print(lyrics_lines)
 
-    mask_function = magnitude * pitch_values
-    mask_function = mask_function/np.max(mask_function)
-    mask_threshold = 0.15
-    mask_window = [1,6]
-    invalid_onsets_idx = []
-    
-    for i in range(len(onsets)):
-        if np.max(mask_function[onsets[i]+mask_window[0]:onsets[i]+mask_window[1]]) < mask_threshold:
-            invalid_onsets_idx.append(i)
-    
-    onsets = np.delete(onsets, invalid_onsets_idx)
 
+    if verbose:
+        print("Preprocessing audio and lyrics...")
+    # load audio and lyrics
+    # words:        a list of words
+    # lyrics_p:     phoneme sequence of the target lyrics
+    # idx_word_p:   indices of word start in lyrics_p
+    # idx_line_p:   indices of line start in lyrics_p
+    audio, words, lyrics_p, idx_word_p, idx_line_p = preprocess_from_file(audio_file, lyrics_lines, word_file, language)
+    if verbose:
+        print(lyrics_p)
 
-    if reference_syls:
-        filtered_onsets = []
-        line_index = 0
-        for line in reference_syls:
-            line_index += 1
-            syl_number = len(line) - 1
-            line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
-            line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
-            missing_syls = 0
-            if syl_number > len(line_onsets):
-                print("WARNING : failed to detect enough onsets in line %d (%d, %d)" % (line_index, line[0][0], line[-1][0]))
-                missing_syls = syl_number - len(line_onsets)
-            filtered_onsets += line_onsets[0:syl_number]
-            filtered_onsets += [line[-1][0] for i in range(missing_syls)] # If missing some syllables, pad with 0-length syls
-        
-        onsets = np.array(sorted(filtered_onsets))
+    if verbose:
+        print("Retrieving syls from lyrics...")
+    # compute alignment
+    # word_align:   a list of frame indices aligned to each word
+    # words:        a list of words
+    word_align, words = align(audio, words, lyrics_p, idx_word_p, idx_line_p, method=method, cuda=cuda, checkpoint_folder=checkpoint_folder)
 
 
-    if backtrack:
-        # Backtrack onsets to closest earlier local minimum
-        backtrack_max_frames = 50
-        for i in range(len(onsets)):
-            initial_onset = onsets[i]
-            while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
-                onsets[i] -= 1
+    words_onsets = np.array([word_align[i][0] for i in range(len(word_align))])
+    print(words_onsets)
 
+    onsets = words_onsets
     onsets = (onsets + delay)/100
     #print(onsets)
 
@@ -100,7 +62,7 @@ def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500):
         for line in reference_syls:
             #print(onset_index, " : ", line)
             l = [[onsets[onset_index + i], line[i][1]] for i in range(len(line)-1)]
-            l.append([line[-1][0]/100, ''])
+            l.append([word_align[onset_index + (len(line) - 2)][1]/100, ''])
             syls.append(l)
             onset_index += (len(line) - 1)
     else:
@@ -121,142 +83,3 @@ def segment(songfile, reference_syls=None, syls_per_line=10, last_syl_dur=500):
     return syls
 
 
-
-if __name__ == "__main__":
-
-    def dateToTime(date):
-        """
-        The `date` should be in the following format: H:MM:SS.cs
-        """
-        hourInMinuts = int(date[0:1]) * 60
-        minutsInSeconds = (int(date[2:4]) + hourInMinuts) * 60
-        secondsInCentiseconds = (int(date[5:7]) + minutsInSeconds) * 100
-        return int(date[8:10]) + secondsInCentiseconds
-
-
-    def getSyls(ass_file):
-        SYLS = []
-        with open(ass_file, 'r') as f:
-            CONTENT = f.read()
-            LINES_KARA = re.compile(r"Comment:.*(\d+:\d{2}:\d{2}.\d{2}),(\d+:\d{2}:\d{2}.\d{2}),.*,karaoke,(.*)\n");
-            RGX_TAGS = re.compile(r"\{\\k(\d+)\}([^\{\n\r]*)")
-            for line in LINES_KARA.findall(CONTENT):
-                syl_line = []
-                lastTime = dateToTime(line[0])
-                for couple in RGX_TAGS.findall(line[2]):
-                    syl_line.append((lastTime, couple[1], int(couple[0])))
-                    lastTime += int(couple[0])
-                syl_line.append([lastTime, '', 0])
-                SYLS.append(syl_line)
-        return SYLS
-
-
-
-
-    songfile = sys.argv[1]
-    if(len(sys.argv) >= 3):
-        reference_syls = getSyls(sys.argv[2])
-    else:
-        reference_syls = None
-    
-    #print(reference_syls)
-
-    backtrack = False
-
-    cnn = madmom.features.onsets.CNNOnsetProcessor()
-    spectral = madmom.features.onsets.SpectralOnsetProcessor('modified_kullback_leibler')
-
-    sig = madmom.audio.signal.Signal(songfile, num_channels=1)
-    parsel = parselmouth.Sound(sig)
-
-    spec = madmom.audio.spectrogram.Spectrogram(sig)
-    filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24)
-    log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1)
-
-    magnitude = np.max(log_spec[:,:100], axis=1)
-
-    cnn_function = cnn(sig)
-    spectral_function = spectral(sig)
-    spectral_function = spectral_function/(spectral_function.max())
-    
-    #activation_function = 0.5*cnn_function + 0.5*spectral_function
-    activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function)
-    #activation_function = np.where(spectral_function > 0.14, cnn_function, 0)
-    #onsets = proc(activation_function)
-    
-
-    if reference_syls:
-        activation_threshold = 0.1
-    else:
-        activation_threshold = 0.2
-
-    activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
-    cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
-    onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=activation_threshold, smooth=0)
-    #onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
-
-    pitch = parsel.to_pitch()
-    pitch_values = pitch.selected_array['frequency']
-
-    pad_before = round(pitch.xs()[0]*100)
-    pad_after = len(magnitude) - len(pitch_values) - pad_before
-
-    pitch_values = np.pad(pitch_values, (pad_before, pad_after), 'constant', constant_values=(0,0))
-
-    mask_function = magnitude * pitch_values
-    mask_function = mask_function/np.max(mask_function)
-    mask_threshold = 0.15
-    mask_window = [1,6]
-    invalid_onsets_idx = []
-    
-    for i in range(len(onsets)):
-        if np.max(mask_function[onsets[i]+mask_window[0]:onsets[i]+mask_window[1]]) < mask_threshold:
-            invalid_onsets_idx.append(i)
-    
-    onsets = np.delete(onsets, invalid_onsets_idx)
-
-
-    if reference_syls:
-        filtered_onsets = []
-        line_index = 0
-        for line in reference_syls:
-            line_index += 1
-            syl_number = len(line) - 1
-            line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
-            line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
-            if syl_number > len(line_onsets):
-                print("WARNING : failed to detect enough onsets in line %d (%d, %d)" % (line_index, line[0][0], line[-1][0]))
-            filtered_onsets += line_onsets[0:syl_number]
-        
-        onsets = np.array(sorted(filtered_onsets))
-    
-    # Backtrack onsets to closest earlier local minimum
-    if backtrack:
-        backtrack_max_frames = 50
-        for i in range(len(onsets)):
-            initial_onset = onsets[i]
-            while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
-                onsets[i] -= 1
-
-    print(onsets/100)
-
-    if reference_syls:
-        reference_onsets = [syl[0]+8 for line in reference_syls for syl in line[:-1]]
-
-    fig, axs = plt.subplots(nrows=2, sharex=True)
-    axs[0].imshow(log_spec.T, origin='lower', aspect='auto')
-    if reference_syls:
-        axs[0].vlines(reference_onsets, 0, 140, colors='red')
-    axs[0].plot((pitch_values/np.max(pitch_values))*140, color='yellow')
-    axs[1].plot(mask_function)
-    #axs[1].plot(cnn_smoothed)
-    #axs[1].plot(spectral_function, color='green')
-    axs[1].plot(activation_smoothed, color='orange')
-    axs[1].vlines(onsets, 0, 2, colors='red')
-    axs[1].hlines([max(mask_threshold, 0), activation_threshold], 0, onsets[-1]+100, colors='black')
-
-    #bins = np.arange(0, 1, 0.02)
-    #hist, hist_axs = plt.subplots(nrows=1)
-    #hist_axs.hist(mask_function, bins=bins)
-
-    plt.show()
diff --git a/g2p/mappings/langs/rji/config.yaml b/g2p/mappings/langs/rji/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..80e5d9854978e2781c709a5112f25d74aceabd97
--- /dev/null
+++ b/g2p/mappings/langs/rji/config.yaml
@@ -0,0 +1,16 @@
+<<: &shared
+  - language_name: Romaji
+mappings:
+  - display_name: Romaji (Hepburn) to English Arpabet
+    in_lang: rji
+    out_lang: rji-eng-arpa
+    authors:
+      - Loïc Allègre
+    type: mapping
+    mapping: romaji_to_eng-arpa.csv
+    abbreviations: rji_abbs.csv
+    rule_ordering: as-written
+    case_sensitive: false
+    norm_form: 'NFC'
+    prevent_feeding: true
+    <<: *shared
\ No newline at end of file
diff --git a/g2p/mappings/langs/rji/rji_abbs.csv b/g2p/mappings/langs/rji/rji_abbs.csv
new file mode 100644
index 0000000000000000000000000000000000000000..241e2f94fe457b167c07bb64bb8812cbc6e0d992
--- /dev/null
+++ b/g2p/mappings/langs/rji/rji_abbs.csv
@@ -0,0 +1,3 @@
+VOWEL,a,e,i,o,u
+EI_VOW,e,i
+AOU_VOW,a,o,u
\ No newline at end of file
diff --git a/g2p/mappings/langs/rji/romaji_to_eng-arpa.csv b/g2p/mappings/langs/rji/romaji_to_eng-arpa.csv
new file mode 100644
index 0000000000000000000000000000000000000000..f265167ec999c90900477e2a215ed78d50c17a11
--- /dev/null
+++ b/g2p/mappings/langs/rji/romaji_to_eng-arpa.csv
@@ -0,0 +1,34 @@
+ch,CH ,,,true
+ou,OW ,,,true
+sh,SH ,,,true
+dj,JH ,,,true
+dz,Z ,,,true
+a,AA ,,,false
+e,EH ,,,false
+i,IY ,,,false
+o,OW ,,,false
+u,UW ,,,false
+k,K ,,y,false
+k,K ,,VOWEL,false
+k,KUW ,,CONS,false
+k,K ,,,false
+s,S ,,VOWEL,false
+s,SUW ,,CONS,false
+s,S ,,,false
+t,T ,,,false
+n,N ,,,false
+h,H ,,,false
+m,M ,,,false
+y,Y ,,,false
+r,L ,,,false
+w,W ,,,false
+g,G ,,,false
+z,Z ,,,false
+d,D ,,,false
+b,B ,,,false
+p,P ,,,false
+f,F ,,,false
+v,V ,,,false
+j,JH ,,,false
+q,K ,,,false
+l,L ,,,false
\ No newline at end of file
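
As a quick sanity check once the mapping files above are copied into the venv and `g2p update` has been run, the `rji` → `rji-eng-arpa` mapping can be exercised directly from Python. This is a minimal sketch assuming the standard `g2p` Python API (`make_g2p`); the sample words are arbitrary and the exact output tokens depend on the mapping rules above.
```python
# Minimal sketch: verify the custom Romaji -> ARPABET mapping after `g2p update`.
# Assumes the public g2p API (make_g2p); adjust if your g2p version differs.
from g2p import make_g2p

transducer = make_g2p("rji", "rji-eng-arpa")

for word in ["shita", "kokoro"]:
    print(word, "->", transducer(word).output_string)
```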
diff --git a/plot_syls.py b/plot_syls.py
new file mode 100644
index 0000000000000000000000000000000000000000..e86960430761fc37dcc54ebfc60dbcc1a1361da1
--- /dev/null
+++ b/plot_syls.py
@@ -0,0 +1,191 @@
+import madmom
+import numpy as np
+import sys
+import re
+import matplotlib.pyplot as plt
+import scipy.signal as sg
+import parselmouth
+import argparse
+
+from autosyl.assUtils import getSyls, timeToDate, dateToTime 
+from autosyl.LyricsAlignment.wrapper import align, preprocess_from_file
+
+
+##############################################################################
+#
+# This is a test script to visualize extracted onsets and other audio features
+# It is mainly intended for development/debug
+#
+# If you just want to detect the syllables, use autokara.py instead
+#
+##############################################################################
+
+
+parser = argparse.ArgumentParser(description='AutoKara - Automatic karaoke timing tool')
+parser.add_argument("vocals_file", type=str, help="The audio file to time")
+parser.add_argument("ass_file", type=str, help="The ASS file with lyrics to time")
+
+args = parser.parse_args()
+
+
+songfile = args.vocals_file
+reference_syls, line_meta = getSyls(args.ass_file)
+
+
+print(reference_syls)
+
+backtrack = False
+
+
+
+
+audio_file = songfile                      # pre-computed source-separated vocals; these models do not work with mixture input
+word_file = None                           # example: jamendolyrics/lyrics/*.words.txt; set to None if you don't have it
+method = "MTL_BDR"                         # one of "Baseline", "MTL", "Baseline_BDR", "MTL_BDR"
+cuda = True                                # set True if you have access to a GPU
+checkpoint_folder = "./autosyl/LyricsAlignment/checkpoints"
+
+pred_file = "./MTL.csv"                    # saved alignment results, "(float) start_time, (float) end_time, (string) word"
+
+
+lyrics_lines = [" ".join([syl[1] for syl in line]) for line in reference_syls]
+#print(lyrics_lines)
+
+
+# load audio and lyrics
+# words:        a list of words
+# lyrics_p:     phoneme sequence of the target lyrics
+# idx_word_p:   indices of word start in lyrics_p
+# idx_line_p:   indices of line start in lyrics_p
+audio, words, lyrics_p, idx_word_p, idx_line_p = preprocess_from_file(audio_file, lyrics_lines, word_file)
+
+# compute alignment
+# word_align:   a list of frame indices aligned to each word
+# words:        a list of words
+word_align, words = align(audio, words, lyrics_p, idx_word_p, idx_line_p, method=method, cuda=cuda, checkpoint_folder=checkpoint_folder)
+
+
+print([[word_align[i][0], word_align[i][1], words[i]] for i in range(len(word_align))])
+words_onsets = np.array([word_align[i][0] for i in range(len(word_align))])
+
+
+cnn = madmom.features.onsets.CNNOnsetProcessor()
+spectral = madmom.features.onsets.SpectralOnsetProcessor('modified_kullback_leibler')
+
+sig = madmom.audio.signal.Signal(songfile, num_channels=1)
+parsel = parselmouth.Sound(sig)
+
+spec = madmom.audio.spectrogram.Spectrogram(sig)
+filt_spec = madmom.audio.spectrogram.FilteredSpectrogram(spec, filterbank=madmom.audio.filters.LogFilterbank, num_bands=24)
+log_spec = madmom.audio.spectrogram.LogarithmicSpectrogram(filt_spec, add=1)
+
+magnitude = np.max(log_spec[:,:100], axis=1)
+
+cnn_function = cnn(sig)
+spectral_function = spectral(sig)
+spectral_function = spectral_function/(spectral_function.max())
+
+#activation_function = 0.5*cnn_function + 0.5*spectral_function
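+# Fuse the CNN and spectral-flux detectors with their harmonic mean: an onset must be supported by both curves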
+activation_function = (2 * cnn_function * spectral_function)/(cnn_function + spectral_function)
+#activation_function = np.where(spectral_function > 0.14, cnn_function, 0)
+#onsets = proc(activation_function)
+
+
+if reference_syls:
+    activation_threshold = 0.1
+else:
+    activation_threshold = 0.2
+
+activation_smoothed = madmom.audio.signal.smooth(activation_function, 20)
+cnn_smoothed = madmom.audio.signal.smooth(cnn_function, 20)
+onsets = madmom.features.onsets.peak_picking(activation_smoothed, threshold=activation_threshold, smooth=0)
+#onsets = np.array([o for o in onsets if cnn_smoothed[o] > 0.1])
+
+pitch = parsel.to_pitch()
+pitch_values = pitch.selected_array['frequency']
+
+pad_before = round(pitch.xs()[0]*100)
+pad_after = len(magnitude) - len(pitch_values) - pad_before
+
+pitch_values = np.pad(pitch_values, (pad_before, pad_after), 'constant', constant_values=(0,0))
+
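+# Voicing mask: spectrogram magnitude weighted by the F0 track, used to reject onsets falling in silent/unvoiced regions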
+mask_function = magnitude * pitch_values
+mask_function = mask_function/np.max(mask_function)
+mask_threshold = 0.15
+mask_window = [1,6]
+invalid_onsets_idx = []
+
+for i in range(len(onsets)):
+    if np.max(mask_function[onsets[i]+mask_window[0]:onsets[i]+mask_window[1]]) < mask_threshold:
+        invalid_onsets_idx.append(i)
+
+onsets = np.delete(onsets, invalid_onsets_idx)
+
+
+
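+# If reference lyrics are available, keep only the strongest onsets within each line (one per expected syllable)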
+if reference_syls:
+    filtered_onsets = []
+    line_index = 0
+    for line in reference_syls:
+        line_index += 1
+        syl_number = len(line) - 1
+        line_onsets = [o for o in onsets if (o >= line[0][0] and o <= line[-1][0])]
+        line_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
+        if syl_number > len(line_onsets):
+            print("WARNING : failed to detect enough onsets in line %d (%d, %d)" % (line_index, line[0][0], line[-1][0]))
+        filtered_onsets += line_onsets[0:syl_number]
+    
+    onsets = np.array(sorted(filtered_onsets))
+
+
+""" 
+    if word_index > 0:
+        word_start = max(word_align[word_index][0] - 5, line[0][0], previous_onset+1)
+    else:
+        word_start = line[0][0]
+    if word_index < len(words) - 1 and syl_index < len(line) - 2:
+        word_end = min(line[-1][0], word_align[word_index + 1][0] - 5)
+    else:
+        word_end = line[-1][0]
+
+    word_onsets = [o for o in onsets if (o >= word_start and o <= word_end)]
+    word_onsets.sort(reverse=True, key=(lambda x: activation_smoothed[x]))
+    if word_syl_count > len(word_onsets):
+        print("WARNING : failed to detect enough onsets in word %s (%d, %d)" % (word_tmp, word_start, word_end))
+    filtered_onsets += word_onsets[0:word_syl_count]
+    print(word_onsets[0:word_syl_count])
+    previous_onset = max(word_onsets[0:word_syl_count] + [0])
+"""
+
+# Backtrack onsets to closest earlier local minimum
+if backtrack:
+    backtrack_max_frames = 50
+    for i in range(len(onsets)):
+        initial_onset = onsets[i]
+        while(activation_smoothed[onsets[i] - 1] < activation_smoothed[onsets[i]] and onsets[i] > initial_onset - backtrack_max_frames):
+            onsets[i] -= 1
+
+#print(onsets/100)
+print(words_onsets/100)
+
+if reference_syls:
+    reference_onsets = [syl[0]+8 for line in reference_syls for syl in line[:-1]]
+
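+# Top panel: log-spectrogram with reference syllable marks (red) and normalized pitch curve (yellow)
+# Bottom panel: mask and smoothed activation curves, detected onsets (red), aligned word onsets (magenta), thresholds (black)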
+fig, axs = plt.subplots(nrows=2, sharex=True)
+axs[0].imshow(log_spec.T, origin='lower', aspect='auto')
+if reference_syls:
+    axs[0].vlines(reference_onsets, 0, 140, colors='red')
+axs[0].plot((pitch_values/np.max(pitch_values))*140, color='yellow')
+axs[1].plot(mask_function)
+#axs[1].plot(cnn_smoothed)
+#axs[1].plot(spectral_function, color='green')
+axs[1].plot(activation_smoothed, color='orange')
+axs[1].vlines(onsets, 0, 2, colors='red')
+axs[1].vlines(words_onsets, 0, 3, colors='m')
+axs[1].hlines([max(mask_threshold, 0), activation_threshold], 0, onsets[-1]+100, colors='black')
+
+#bins = np.arange(0, 1, 0.02)
+#hist, hist_axs = plt.subplots(nrows=1)
+#hist_axs.hist(mask_function, bins=bins)
+
+plt.show()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 037cec811b5956b459fe862eae3b0b82dfc61d45..04299146d72785e9cf25e747f7bcea65dd11d05c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,16 @@ scipy
 cython
 mido
 git+https://github.com/CPJKU/madmom.git
-praat-parselmouth
\ No newline at end of file
+praat-parselmouth
+future
+musdb
+museval
+h5py
+tqdm
+torch>=1.8.0
+torchaudio
+tensorboard
+sortedcontainers
+g2p_en
+g2p
+resampy
\ No newline at end of file