diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py index e5563319e..9b5e16084 100644 --- a/egs/librispeech/ASR/transducer/model.py +++ b/egs/librispeech/ASR/transducer/model.py @@ -108,10 +108,11 @@ class Transducer(nn.Module): y_padded = y_padded.to(torch.int64) simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple( - decoder_out, encoder_out, y_padded, blank_id, boundary, True) + decoder_out, encoder_out, y_padded, blank_id, boundary, + return_grad=True) # TODO: make s_range configurable - ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, x_lens, s_range=5) + ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, s_range=5) am_pruned, lm_pruned = k2.do_rnnt_pruning( encoder_out, decoder_out, ranges)