From 593a6e946db02b4a0a1731bca1a361eb760a48c4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Oct 2022 15:36:55 +0800 Subject: [PATCH] Fix an issue with scaling of grad. --- egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index e9fc91f24..21653e54d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -859,7 +859,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * (x.to(torch.float32).norm() / + scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20)) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None