Add bias to BasicNorm

Daniel Povey 2022-12-22 15:14:49 +08:00
parent b39cde85c8
commit 903955f5d9


@@ -430,32 +430,44 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None


-class ComputeSquaredMeanWithOffset(torch.autograd.Function):
+class BasicNormFunction(torch.autograd.Function):
+    # This computes:
+    #   scales = (torch.mean((x + bias) ** 2, keepdim=True) + eps.exp()) ** -0.5
+    #   return x * scales
+    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
+    # it can just store the returned value (chances are, this will also be needed for
+    # some other reason, related to the next operation, so we can save memory).
     @staticmethod
     @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, channel_dim: int) -> Tensor:
+    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int) -> Tensor:
         assert bias.ndim == 1
         if channel_dim < 0:
             channel_dim = channel_dim + x.ndim
         ctx.channel_dim = channel_dim
         for _ in range(channel_dim + 1, x.ndim):
             bias = bias.unsqueeze(-1)
-        ans = torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True)
-        ctx.save_for_backward(x, bias)
+        scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
+        ans = x * scales
+        ctx.save_for_backward(ans, scales, bias, eps)
         return ans

     @staticmethod
     @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tensor:
-        x, bias = ctx.saved_tensors
+        ans, scales, bias, eps = ctx.saved_tensors
+        x = ans / scales
         x = x.detach()
         bias = bias.detach()
+        eps = eps.detach()
         x.requires_grad = True
         bias.requires_grad = True
+        eps.requires_grad = True
         with torch.enable_grad():
-            ans = torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True)
+            # recompute scales from x, bias and eps.
+            scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
+            ans = x * scales
             ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), None
+        return x.grad, bias.grad.flatten(), eps.grad, None
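The memory saving comes from forward storing only the output ans and the (small) scales tensor rather than the input x; backward reconstructs x = ans / scales and replays the computation under torch.enable_grad() so autograd produces the gradients. A quick way to sanity-check a hand-written backward like this is torch.autograd.gradcheck. The sketch below is mine, not part of the commit, and assumes the BasicNormFunction class above with custom_fwd and custom_bwd imported from torch.cuda.amp:

    import torch

    # Hypothetical self-check (not from the commit).  channel_dim=-1 means
    # the bias needs no unsqueezing; gradcheck wants double precision.
    torch.manual_seed(0)
    x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
    bias = torch.randn(4, dtype=torch.double, requires_grad=True)
    eps = torch.tensor(-1.0, dtype=torch.double, requires_grad=True)

    # Compares the hand-written backward against numerical gradients.
    torch.autograd.gradcheck(
        lambda x, bias, eps: BasicNormFunction.apply(x, bias, eps, -1),
        (x, bias, eps),
    )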
@@ -512,6 +524,16 @@ class BasicNorm(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         eps = self.eps
+        if torch.jit.is_scripting():
+            # TorchScript cannot compile BasicNormFunction; compute the scales inline.
+            channel_dim = self.channel_dim
+            if channel_dim < 0:
+                channel_dim = channel_dim + x.ndim
+            bias = self.bias
+            for _ in range(channel_dim + 1, x.ndim):
+                bias = bias.unsqueeze(-1)
+            scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
+            return x * scales
         if self.training and random.random() < 0.25:
             # with probability 0.25, in training mode, clamp eps between the min
             # and max; this will encourage it to learn parameters within the
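The explicit torch.jit.is_scripting() branch is needed because TorchScript cannot compile calls to custom torch.autograd.Function subclasses; the scripted path recomputes the scales inline, and the compiler drops the untaken branch. A minimal sketch of exercising both paths (my example; BasicNorm's constructor arguments are assumed, not shown in this diff):

    import torch

    m = BasicNorm(num_channels=4)   # assumed constructor signature
    m.eval()                        # avoid the random eps-clamping branch
    scripted = torch.jit.script(m)  # compiles only the is_scripting() path

    x = torch.randn(10, 4)
    assert torch.allclose(m(x), scripted(x))  # both paths should agree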
@@ -522,11 +544,7 @@ class BasicNorm(torch.nn.Module):
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
-        norms = ComputeSquaredMeanWithOffset.apply(x, self.bias, self.channel_dim)
-        scales = (
-            norms + eps.exp()
-        ) ** -0.5
-        return x * scales
+        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim)
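Taken together, forward now normalizes by the second moment of x + bias rather than of x alone; presumably the same commit also registers self.bias as an nn.Parameter in __init__, which is not part of this excerpt. A usage sketch under that assumption:

    import torch

    m = BasicNorm(num_channels=4)   # assumed signature; self.bias assumed registered
    m.eval()
    x = torch.randn(10, 4)
    y = m(x)

    # Equivalent arithmetic for the default channel_dim == -1 (eps is stored
    # in log space, hence the .exp()):
    scales = (torch.mean((x + m.bias) ** 2, dim=-1, keepdim=True)
              + m.eps.exp()) ** -0.5
    assert torch.allclose(y, x * scales)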