Update train.py

This commit is contained in:
jinzr 2023-11-30 22:50:29 +08:00
parent cbf8b2d36c
commit 0ef3da24c1

View File

@ -338,7 +338,7 @@ def prepare_input(
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).to(device)
speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = k2.RaggedTensor(tokens)