From 6ea1706e113f3d554d8793c3d0e9d3838cea3f55 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 14 Nov 2022 23:31:00 +0800 Subject: [PATCH] Fix potential/theoretical issue in backward of LimitParamValue --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index cf2f7f3aa..174ccf39e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -728,7 +728,7 @@ class LimitParamValue(torch.autograd.Function): x, = ctx.saved_tensors # where x < ctx.min, ensure all grads are negative (this will tend to make # x more positive). - x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) + x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) # where x > ctx.max, ensure all grads are positive (this will tend to make # x more negative). x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)