mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Subtract bias after scaling
This commit is contained in:
parent
d31e2e12c6
commit
a0b2276f68
@@ -449,7 +449,7 @@ class BasicNormFunction(torch.autograd.Function):
         for _ in range(channel_dim + 1, x.ndim):
             bias = bias.unsqueeze(-1)
         scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
-        ans = x * scales
+        ans = x * scales - bias
         ctx.save_for_backward(ans if store_output_for_backprop else x,
                               scales, bias, eps)
         return ans
@@ -471,7 +471,7 @@ class BasicNormFunction(torch.autograd.Function):
         with torch.enable_grad():
             # recompute scales from x, bias and eps.
             scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
-            ans = x * scales
+            ans = x * scales - bias
             ans.backward(gradient=ans_grad)
         return x.grad, bias.grad.flatten(), eps.grad, None, None
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user