Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Bug fix regarding bias

parent a0b2276f68
commit 1dbe1e4086
@@ -432,8 +432,8 @@ class MaxEigLimiterFunction(torch.autograd.Function):
 
 class BasicNormFunction(torch.autograd.Function):
     # This computes:
-    #   scales = torch.mean((x + bias) ** 2, keepdim=True) + eps.exp()
-    #   return x * scales
+    #   scales = torch.mean((x - bias) ** 2, keepdim=True) + eps.exp()
+    #   return (x - bias) * scales
     # (after unsqueezing the bias), but it does it in a memory-efficient way so that
     # it can just store the returned value (chances are, this will also be needed for
     # some other reason, related to the next operation, so we can save memory).
@@ -448,8 +448,8 @@ class BasicNormFunction(torch.autograd.Function):
         ctx.channel_dim = channel_dim
         for _ in range(channel_dim + 1, x.ndim):
             bias = bias.unsqueeze(-1)
-        scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
-        ans = x * scales - bias
+        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
+        ans = x * scales
         ctx.save_for_backward(ans if store_output_for_backprop else x,
                               scales, bias, eps)
         return ans
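For orientation, here is a minimal plain-PyTorch sketch of what the fixed forward pass computes, without the memory-saving custom backward of BasicNormFunction. The helper name basic_norm_fixed and the example shapes are illustrative assumptions, not part of the commit; the argument names x, bias, eps and channel_dim follow the diff.

import torch


def basic_norm_fixed(x: torch.Tensor,
                     bias: torch.Tensor,
                     eps: torch.Tensor,
                     channel_dim: int) -> torch.Tensor:
    # Unsqueeze the per-channel bias so it broadcasts over the trailing dims,
    # mirroring the loop in the forward() hunk above.
    for _ in range(channel_dim + 1, x.ndim):
        bias = bias.unsqueeze(-1)
    # Fixed line: the mean-square statistic is taken over (x - bias), not (x + bias).
    scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True)
              + eps.exp()) ** -0.5
    # Fixed line: the output no longer subtracts the bias a second time.
    return x * scales


x = torch.randn(4, 10, 256)    # (batch, time, channels); made-up shapes
bias = torch.zeros(256)        # per-channel bias parameter
eps = torch.tensor(-3.0)       # stored in log space, so eps.exp() > 0
print(basic_norm_fixed(x, bias, eps, channel_dim=2).shape)  # torch.Size([4, 10, 256])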
@@ -470,8 +470,8 @@ class BasicNormFunction(torch.autograd.Function):
         eps.requires_grad = True
         with torch.enable_grad():
             # recompute scales from x, bias and eps.
-            scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
-            ans = x * scales - bias
+            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
+            ans = x * scales
             ans.backward(gradient=ans_grad)
         return x.grad, bias.grad.flatten(), eps.grad, None, None
 
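The backward hunk relies on recomputing the cheap normalization inside backward() instead of saving intermediates. As a sanity check, that recompute-inside-backward pattern can be compared against ordinary autograd on the same fixed formula. Everything below is an illustrative sketch with made-up shapes and names, not code from the repository.

import torch

channel_dim = 2
x0 = torch.randn(3, 5, 8)        # made-up shapes: (batch, time, channels)
bias0 = torch.randn(8)
eps0 = torch.tensor(-3.0)
ans_grad = torch.randn(3, 5, 8)  # stand-in for the incoming output gradient


def forward(x, bias, eps):
    # Same formula as the fixed forward above; channel_dim is already the
    # last dim here, so no unsqueezing of bias is needed.
    scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True)
              + eps.exp()) ** -0.5
    return x * scales


# (a) Reference gradients via ordinary autograd.
x_a, b_a, e_a = (t.clone().requires_grad_(True) for t in (x0, bias0, eps0))
forward(x_a, b_a, e_a).backward(gradient=ans_grad)

# (b) The recompute-inside-backward pattern from the hunk: start from detached
# tensors, turn grad tracking back on, recompute, then backprop through the
# recomputed graph.
x_b, b_b, e_b = (t.detach().clone() for t in (x0, bias0, eps0))
for t in (x_b, b_b, e_b):
    t.requires_grad = True
with torch.enable_grad():
    ans = forward(x_b, b_b, e_b)
    ans.backward(gradient=ans_grad)

for g_a, g_b in zip((x_a.grad, b_a.grad, e_a.grad),
                    (x_b.grad, b_b.grad, e_b.grad)):
    assert torch.allclose(g_a, g_b, atol=1e-6)
print("recomputed backward matches plain autograd")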