Minor fixes

This commit is contained in:
pkufool 2021-12-24 08:05:41 +08:00
parent 2d6631ac76
commit eb29202168

View File

@ -108,10 +108,11 @@ class Transducer(nn.Module):
y_padded = y_padded.to(torch.int64)
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
decoder_out, encoder_out, y_padded, blank_id, boundary, True)
decoder_out, encoder_out, y_padded, blank_id, boundary,
return_grad=True)
# TODO: make s_range configurable
ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, x_lens, s_range=5)
ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, s_range=5)
am_pruned, lm_pruned = k2.do_rnnt_pruning(
encoder_out, decoder_out, ranges)