diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index c569fafad..698db781c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -555,7 +555,14 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     it shouldn't really matter, or may even be helpful; we just use this to
     disallow really implausible values of scores to be given to softmax.
     """
-    aux_loss = penalty * (x.abs() - limit).relu()
+    x_sign = x.sign()
+    over_limit = (x.abs() - limit) > 0
+    # The following is a memory-efficient way to penalize the absolute values of
+    # x that exceed the limit.  The numerical value of aux_loss as computed here
+    # will actually be larger than it should be, but it has the same derivative as
+    #     penalty * (x.abs() - limit).relu(),
+    # which is what we really want to penalize.
+    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
     # note: we don't do sum() here on aux_loss, but it's as if we had done
     # sum() due to how with_loss() works.
     x = with_loss(x, aux_loss)
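
Below is a minimal standalone sketch, not part of the patch, that checks the claim in the new comment: the int8 formulation yields a larger loss value but the same gradient as penalty * (x.abs() - limit).relu(). The penalty/limit values and the tensor shape are arbitrary choices for illustration; double precision is used so the gradient comparison is exact.

    import torch

    penalty, limit = 1.0e-04, 10.0                    # arbitrary illustrative values
    x = 20.0 * torch.randn(4, 5, dtype=torch.double)  # some entries will exceed the limit

    # Reference version: straightforward, but autograd keeps full-precision
    # intermediates (from abs() and relu()) alive for the backward pass.
    xa = x.clone().requires_grad_(True)
    (penalty * (xa.abs() - limit).relu()).sum().backward()

    # Patched version: casting to int8 detaches the sign/mask product from the
    # graph, so the multiply only needs to save a small int8 tensor (plus x
    # itself) for backward, rather than full-precision intermediates.
    xb = x.clone().requires_grad_(True)
    x_sign = xb.sign()
    over_limit = (xb.abs() - limit) > 0
    (penalty * ((x_sign * over_limit).to(torch.int8) * xb)).sum().backward()

    # Same gradient, even though the loss values differ by penalty * limit
    # for every element over the limit.
    assert torch.allclose(xa.grad, xb.grad)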