Implement bias in BasicNorm

Daniel Povey 2022-12-22 14:59:29 +08:00
parent 5aa874d8e3
commit b39cde85c8


@@ -430,6 +430,35 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None


+class ComputeSquaredMeanWithOffset(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x: Tensor, bias: Tensor, channel_dim: int) -> Tensor:
+        assert bias.ndim == 1
+        if channel_dim < 0:
+            channel_dim = channel_dim + x.ndim
+        ctx.channel_dim = channel_dim
+        for _ in range(channel_dim + 1, x.ndim):
+            bias = bias.unsqueeze(-1)
+        ans = torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True)
+        ctx.save_for_backward(x, bias)
+        return ans
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad: Tensor) -> Tensor:
+        x, bias = ctx.saved_tensors
+        x = x.detach()
+        bias = bias.detach()
+        x.requires_grad = True
+        bias.requires_grad = True
+        with torch.enable_grad():
+            ans = torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True)
+            ans.backward(gradient=ans_grad)
+        return x.grad, bias.grad.flatten(), None
+
+
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
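
The new autograd function returns mean((x + bias) ** 2) over the channel dimension; its backward simply re-runs the same expression under torch.enable_grad() and lets autograd supply the gradients (custom_fwd/custom_bwd are the usual torch.cuda.amp decorators already used elsewhere in this file). A minimal sanity check, as a sketch outside the diff; the shapes and the gradcheck call are illustrative and assume ComputeSquaredMeanWithOffset is importable from this file:

# sketch of a check against the direct computation (illustrative shapes)
import torch

x = torch.randn(4, 6, 5, dtype=torch.double, requires_grad=True)   # (batch, channels, frames)
bias = torch.randn(6, dtype=torch.double, requires_grad=True)      # one offset per channel
channel_dim = 1

ref = torch.mean((x + bias.unsqueeze(-1)) ** 2, dim=channel_dim, keepdim=True)
out = ComputeSquaredMeanWithOffset.apply(x, bias, channel_dim)
assert torch.allclose(out, ref)

# compare the hand-written backward against numerical gradients
torch.autograd.gradcheck(
    lambda x, b: ComputeSquaredMeanWithOffset.apply(x, b, channel_dim), (x, bias)
)
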
@@ -476,6 +505,7 @@ class BasicNorm(torch.nn.Module):
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
             self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps_min = eps_min
         self.eps_max = eps_max
@@ -491,8 +521,10 @@ class BasicNorm(torch.nn.Module):
             # gradients to allow the parameter to get back into the allowed
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+        norms = ComputeSquaredMeanWithOffset.apply(x, self.bias, self.channel_dim)
         scales = (
-            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+            norms + eps.exp()
         ) ** -0.5
         return x * scales
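
With these two changes, BasicNorm's forward is equivalent to y = x * (mean((x + bias) ** 2, dim=channel_dim) + exp(eps)) ** -0.5: the learned per-channel bias only shifts the statistic used for normalization and is never added to the output. A plain-PyTorch sketch of the same computation (hypothetical helper name; it ignores the eps clamping branch shown above):

import torch

def basic_norm_with_bias(x, bias, eps_log, channel_dim=-1):
    # bias enters only the normalization statistic, not the returned value
    shape = [1] * x.ndim
    shape[channel_dim] = -1
    norms = torch.mean((x + bias.reshape(shape)) ** 2, dim=channel_dim, keepdim=True)
    scales = (norms + eps_log.exp()) ** -0.5
    return x * scales
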
@@ -1071,11 +1103,9 @@ class Whiten(nn.Module):
         self.grad_scale = grad_scale
         if isinstance(prob, float):
-            assert 0 < prob <= 1
-            self.prob = prob
-        else:
-            (self.min_prob, self.max_prob) = prob
-            assert 0 < self.min_prob < self.max_prob <= 1
-            self.prob = self.max_prob
+            prob = (prob, prob)
+        (self.min_prob, self.max_prob) = prob
+        assert 0 < self.min_prob <= self.max_prob <= 1
+        self.prob = self.max_prob
         self.name = None  # will be set in training loop
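
The Whiten constructor now treats a bare float prob as a degenerate (min, max) range, and the assertion allows equal endpoints. The new handling in isolation, as a hypothetical standalone helper (not part of the commit):

from typing import Tuple, Union

def resolve_prob(prob: Union[float, Tuple[float, float]]) -> Tuple[float, float]:
    # a bare float p is shorthand for the range (p, p); min == max is allowed
    if isinstance(prob, float):
        prob = (prob, prob)
    (min_prob, max_prob) = prob
    assert 0 < min_prob <= max_prob <= 1
    return (min_prob, max_prob)

assert resolve_prob(0.25) == (0.25, 0.25)
assert resolve_prob((0.025, 0.25)) == (0.025, 0.25)
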
@@ -1718,7 +1748,7 @@ def _test_activation_balancer_magnitude():
         max_factor=0.2,
         min_abs=0.2,
         max_abs=0.8,
-        min_prob=1.0,
+        prob=1.0,
     )
     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))