mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Minor fixes
This commit is contained in:
parent
2d6631ac76
commit
eb29202168
@ -108,10 +108,11 @@ class Transducer(nn.Module):
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
|
||||
decoder_out, encoder_out, y_padded, blank_id, boundary, True)
|
||||
decoder_out, encoder_out, y_padded, blank_id, boundary,
|
||||
return_grad=True)
|
||||
|
||||
# TODO: make s_range configurable
|
||||
ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, x_lens, s_range=5)
|
||||
ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, s_range=5)
|
||||
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
encoder_out, decoder_out, ranges)
|
||||
|
Loading…
x
Reference in New Issue
Block a user