From 595d4a3c47f5d3cd24f60fe071011aa458ac1716 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 21 Feb 2024 17:50:19 +0800 Subject: [PATCH] modify usage of tokenizer in vits/train.py --- egs/ljspeech/TTS/vits/train.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 71c4224fa..6589b75ff 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -296,14 +296,16 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): features_lens = batch["features_lens"].to(device) tokens = batch["tokens"] - tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) # a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) return audio, audio_lens, features, features_lens, tokens, tokens_lens @@ -742,8 +744,7 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.blank_id - params.oov_id = tokenizer.oov_id + params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size logging.info(params)