From 16ccbc5bc87b04ad6e0a16ce4846501e35a02973 Mon Sep 17 00:00:00 2001 From: Erwan Date: Fri, 1 Mar 2024 10:31:26 +0100 Subject: [PATCH] Add new phonemizer to infer.py --- egs/ljspeech/TTS/vits2/infer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/ljspeech/TTS/vits2/infer.py b/egs/ljspeech/TTS/vits2/infer.py index cf0d20ae2..9e7c71c6d 100755 --- a/egs/ljspeech/TTS/vits2/infer.py +++ b/egs/ljspeech/TTS/vits2/infer.py @@ -130,14 +130,16 @@ def infer_dataset( batch_size = len(batch["tokens"]) tokens = batch["tokens"] - tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) # tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) audio = batch["audio"] audio_lens = batch["audio_lens"].tolist() @@ -201,8 +203,7 @@ def main(): device = torch.device("cuda", 0) tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.blank_id - params.oov_id = tokenizer.oov_id + params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size logging.info(f"Device: {device}")