Fix for diagnostic (#1135)

* CTC loss return tensor

* Update model.py
This commit is contained in:
Yifan Yang 2023-06-16 15:04:41 +08:00 committed by GitHub
parent 0a465794a8
commit d667dc365b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -340,8 +340,8 @@ class AsrModel(nn.Module):
lm_scale=lm_scale,
)
else:
simple_loss = 0
pruned_loss = 0
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
@ -353,6 +353,6 @@ class AsrModel(nn.Module):
target_lengths=y_lens,
)
else:
ctc_loss = 0
ctc_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss