diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py index 0322edeed..c19325a15 100644 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ b/egs/aishell/ASR/transducer_stateless/model.py @@ -122,4 +122,4 @@ class Transducer(nn.Module): loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary) - return torch.sum(loss) + return loss