mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Minor fixes
This commit is contained in:
parent
2d6631ac76
commit
eb29202168
@ -108,10 +108,11 @@ class Transducer(nn.Module):
|
|||||||
y_padded = y_padded.to(torch.int64)
|
y_padded = y_padded.to(torch.int64)
|
||||||
|
|
||||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
|
||||||
decoder_out, encoder_out, y_padded, blank_id, boundary, True)
|
decoder_out, encoder_out, y_padded, blank_id, boundary,
|
||||||
|
return_grad=True)
|
||||||
|
|
||||||
# TODO: make s_range configurable
|
# TODO: make s_range configurable
|
||||||
ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, x_lens, s_range=5)
|
ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, s_range=5)
|
||||||
|
|
||||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||||
encoder_out, decoder_out, ranges)
|
encoder_out, decoder_out, ranges)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user