mirror of https://github.com/k2-fsa/icefall.git

Add bias to BasicNorm

parent b39cde85c8
commit 903955f5d9
@@ -430,32 +430,44 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None


-class ComputeSquaredMeanWithOffset(torch.autograd.Function):
+class BasicNormFunction(torch.autograd.Function):
+    # This computes:
+    #   scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
+    #   return x * scales
+    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
+    # it can just store the returned value (chances are, this will also be needed for
+    # some other reason, related to the next operation, so we can save memory).
     @staticmethod
     @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, channel_dim: int) -> Tensor:
+    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int) -> Tensor:
         assert bias.ndim == 1
         if channel_dim < 0:
             channel_dim = channel_dim + x.ndim
         ctx.channel_dim = channel_dim
         for _ in range(channel_dim + 1, x.ndim):
             bias = bias.unsqueeze(-1)
-        ans = torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True)
-        ctx.save_for_backward(x, bias)
+        scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
+        ans = x * scales
+        ctx.save_for_backward(ans, scales, bias, eps)
         return ans

     @staticmethod
     @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tensor:
-        x, bias = ctx.saved_tensors
+        ans, scales, bias, eps = ctx.saved_tensors
+        x = ans / scales
         x = x.detach()
         bias = bias.detach()
+        eps = eps.detach()
         x.requires_grad = True
         bias.requires_grad = True
+        eps.requires_grad = True
         with torch.enable_grad():
-            ans = torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True)
+            # recompute scales from x, bias and eps.
+            scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
+            ans = x * scales
             ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), None
+        return x.grad, bias.grad.flatten(), eps.grad, None


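The trick in BasicNormFunction is that the forward pass saves only its output (ans) and the small per-frame scales, not the input: the backward pass reconstructs x = ans / scales, then redoes the forward under torch.enable_grad() to get exact gradients. The sketch below is illustrative only, not part of the commit: the AMP decorators are dropped and all shapes are invented. It checks that this recomputed backward matches plain autograd on the direct formula.

# sanity_check_basic_norm_function.py -- illustrative sketch, not icefall code.
import torch
from torch import Tensor


class MemoryEfficientBasicNorm(torch.autograd.Function):
    # Same math as BasicNormFunction in the diff above, minus @custom_fwd/@custom_bwd.
    @staticmethod
    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int) -> Tensor:
        assert bias.ndim == 1
        if channel_dim < 0:
            channel_dim = channel_dim + x.ndim
        ctx.channel_dim = channel_dim
        for _ in range(channel_dim + 1, x.ndim):
            bias = bias.unsqueeze(-1)
        scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
        ans = x * scales
        ctx.save_for_backward(ans, scales, bias, eps)  # note: x itself is not saved
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ans, scales, bias, eps = ctx.saved_tensors
        x = (ans / scales).detach().requires_grad_(True)  # reconstruct the input
        bias = bias.detach().requires_grad_(True)
        eps = eps.detach().requires_grad_(True)
        with torch.enable_grad():
            scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
            (x * scales).backward(gradient=ans_grad)
        return x.grad, bias.grad.flatten(), eps.grad, None


torch.manual_seed(0)
x = torch.randn(3, 4, 8, dtype=torch.double, requires_grad=True)
bias = torch.randn(8, dtype=torch.double, requires_grad=True)
eps = torch.tensor(-1.0, dtype=torch.double, requires_grad=True)
g = torch.randn(3, 4, 8, dtype=torch.double)

y = MemoryEfficientBasicNorm.apply(x, bias, eps, -1)
grads_custom = torch.autograd.grad(y, (x, bias, eps), grad_outputs=g)

scales = (torch.mean((x + bias) ** 2, dim=-1, keepdim=True) + eps.exp()) ** -0.5
grads_direct = torch.autograd.grad(x * scales, (x, bias, eps), grad_outputs=g)

for a, b in zip(grads_custom, grads_direct):
    assert torch.allclose(a, b)
print("custom backward matches direct autograd")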
@@ -512,6 +524,16 @@ class BasicNorm(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         eps = self.eps
+
+        if torch.jit.is_scripting():
+            channel_dim = self.channel_dim
+            if channel_dim < 0:
+                channel_dim = channel_dim + x.ndim
+            bias = self.bias
+            for _ in range(channel_dim + 1, x.ndim):
+                bias = bias.unsqueeze(-1)
+            scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
+            return x * scales
         if self.training and random.random() < 0.25:
             # with probability 0.25, in training mode, clamp eps between the min
             # and max; this will encourage it to learn parameters within the
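The torch.jit.is_scripting() branch exists because a custom autograd.Function is not scriptable, so a TorchScript export gets the equivalent plain-tensor expression instead (at the cost of keeping x alive for autograd). Written out on its own, the fallback is per-frame RMS normalization with a learned offset and floor; a small sketch with invented shapes and values:

# Direct form of the scripting fallback -- a sketch, not icefall code.
import torch

num_channels = 8
x = torch.randn(5, num_channels)   # (frames, channels), channel_dim = -1
bias = torch.zeros(num_channels)   # learned in the real module; zero here
eps = torch.tensor(-5.0)           # learned log-space floor

scales = (torch.mean((x + bias) ** 2, dim=-1, keepdim=True) + eps.exp()) ** -0.5
y = x * scales

# With bias == 0 and eps.exp() small, every frame of y has RMS close to 1:
print(y.pow(2).mean(dim=-1).sqrt())

Note that bias enters only the scale computation: the output is x * scales, not (x + bias) * scales, so the learned offset affects how each frame's magnitude is measured but is never added to the output.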
@@ -522,11 +544,7 @@ class BasicNorm(torch.nn.Module):
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)

-        norms = ComputeSquaredMeanWithOffset.apply(x, self.bias, self.channel_dim)
-        scales = (
-            norms + eps.exp()
-        ) ** -0.5
-        return x * scales
+        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim)


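For completeness, the eps-clamping context lines above use a stochastic-clamp pattern: most training steps optimize the raw parameter, while a quarter of them run the forward pass with eps clamped to [eps_min, eps_max]. A minimal sketch of the pattern (names and ranges invented):

# Stochastic clamping of a learned parameter -- a sketch of the pattern above.
import random
import torch

eps = torch.nn.Parameter(torch.tensor(-3.0))  # hypothetical learned log-eps
eps_min, eps_max = -3.0, 3.0

def effective_eps(training: bool) -> torch.Tensor:
    e = eps
    if training and random.random() < 0.25:
        # A quarter of training steps see the clamped value, so the rest of
        # the model keeps working if eps wanders outside the range; clamp()
        # passes no gradient when its input is out of range, so those steps
        # also stop pushing eps further out.
        e = e.clamp(min=eps_min, max=eps_max)
    return e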