From 1dbe1e4086fb0f45437e14dec1d29a7bd84d2978 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Dec 2022 17:35:29 +0800 Subject: [PATCH] Bug fix regarding bias --- .../ASR/pruned_transducer_stateless7/scaling.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 706c3c1c7..a4bb01daa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -432,8 +432,8 @@ class MaxEigLimiterFunction(torch.autograd.Function): class BasicNormFunction(torch.autograd.Function): # This computes: - # scales = torch.mean((x + bias) ** 2, keepdim=True) + eps.exp() - # return x * scales + # scales = torch.mean((x - bias) ** 2, keepdim=True) + eps.exp() + # return (x - bias) * scales # (after unsqueezing the bias), but it does it in a memory-efficient way so that # it can just store the returned value (chances are, this will also be needed for # some other reason, related to the next operation, so we can save memory). @@ -448,8 +448,8 @@ class BasicNormFunction(torch.autograd.Function): ctx.channel_dim = channel_dim for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5 - ans = x * scales - bias + scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5 + ans = x * scales ctx.save_for_backward(ans if store_output_for_backprop else x, scales, bias, eps) return ans @@ -470,8 +470,8 @@ class BasicNormFunction(torch.autograd.Function): eps.requires_grad = True with torch.enable_grad(): # recompute scales from x, bias and eps. - scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5 - ans = x * scales - bias + scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5 + ans = x * scales ans.backward(gradient=ans_grad) return x.grad, bias.grad.flatten(), eps.grad, None, None