From a0b2276f68a8a0b4eabf6526fb2831e248550bf3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Dec 2022 15:45:45 +0800 Subject: [PATCH] Subtract bias after scaling --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index f370f3386..706c3c1c7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -449,7 +449,7 @@ class BasicNormFunction(torch.autograd.Function): for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5 - ans = x * scales + ans = x * scales - bias ctx.save_for_backward(ans if store_output_for_backprop else x, scales, bias, eps) return ans @@ -471,7 +471,7 @@ class BasicNormFunction(torch.autograd.Function): with torch.enable_grad(): # recompute scales from x, bias and eps. scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5 - ans = x * scales + ans = x * scales - bias ans.backward(gradient=ans_grad) return x.grad, bias.grad.flatten(), eps.grad, None, None