diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py index d9f250cc9..9e273d42a 100755 --- a/egs/vctk/TTS/vits/train.py +++ b/egs/vctk/TTS/vits/train.py @@ -338,7 +338,9 @@ def prepare_input( audio_lens = batch["audio_lens"].to(device) features_lens = batch["features_lens"].to(device) tokens = batch["tokens"] - speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device) + speakers = ( + torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device) + ) tokens = tokenizer.tokens_to_token_ids(tokens) tokens = k2.RaggedTensor(tokens)