diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py
index 3fc637a46..d9f250cc9 100755
--- a/egs/vctk/TTS/vits/train.py
+++ b/egs/vctk/TTS/vits/train.py
@@ -338,7 +338,7 @@ def prepare_input(
     audio_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
-    speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).to(device)
+    speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
 
     tokens = tokenizer.tokens_to_token_ids(tokens)
     tokens = k2.RaggedTensor(tokens)