Update train.py

This commit is contained in:
Yifan Yang 2023-06-15 16:54:15 +08:00 committed by GitHub
parent 829f188816
commit e307ac21c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -789,7 +789,7 @@ def compute_loss(
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int) y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( simple_loss, pruned_loss, ctc_loss = model(