Update train.py

This commit is contained in:
jinzr 2023-11-30 23:04:11 +08:00
parent 0ef3da24c1
commit 48c90df9c9

View File

@ -338,7 +338,9 @@ def prepare_input(
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
speakers = torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
)
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = k2.RaggedTensor(tokens)