Subtract bias after scaling

This commit is contained in:
Daniel Povey 2022-12-22 15:45:45 +08:00
parent d31e2e12c6
commit a0b2276f68

View File

@ -449,7 +449,7 @@ class BasicNormFunction(torch.autograd.Function):
for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1)
scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
ans = x * scales
ans = x * scales - bias
ctx.save_for_backward(ans if store_output_for_backprop else x,
scales, bias, eps)
return ans
@ -471,7 +471,7 @@ class BasicNormFunction(torch.autograd.Function):
with torch.enable_grad():
# recompute scales from x, bias and eps.
scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
ans = x * scales
ans = x * scales - bias
ans.backward(gradient=ans_grad)
return x.grad, bias.grad.flatten(), eps.grad, None, None