Fix for diagnostic (#1135)

* CTC loss return tensor

* Update model.py
This commit is contained in:
Yifan Yang 2023-06-16 15:04:41 +08:00 committed by GitHub
parent 0a465794a8
commit d667dc365b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -340,8 +340,8 @@ class AsrModel(nn.Module):
lm_scale=lm_scale, lm_scale=lm_scale,
) )
else: else:
simple_loss = 0 simple_loss = torch.empty(0)
pruned_loss = 0 pruned_loss = torch.empty(0)
if self.use_ctc: if self.use_ctc:
# Compute CTC loss # Compute CTC loss
@ -353,6 +353,6 @@ class AsrModel(nn.Module):
target_lengths=y_lens, target_lengths=y_lens,
) )
else: else:
ctc_loss = 0 ctc_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss return simple_loss, pruned_loss, ctc_loss