From 6a6df19bde0561815a9cf0ba60fd753ac0073f51 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Oct 2022 18:34:04 +0800 Subject: [PATCH] Hopefully make penalize_abs_values_gt more memory efficient. --- .../ASR/pruned_transducer_stateless7/scaling.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index c569fafad..698db781c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -555,7 +555,14 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: it shouldn't really matter, or may even be helpful; we just use this to disallow really implausible values of scores to be given to softmax. """ - aux_loss = penalty * (x.abs() - limit).relu() + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. the numerical value of aux_loss as computed here will actually be + # larger than it should be, but it has the same derivative as + # penalty * (x.abs() - limit).relu() + # which is what we really want to penalize + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) # note: we don't do sum() here on aux)_loss, but it's as if we had done # sum() due to how with_loss() works. x = with_loss(x, aux_loss)