Remove bias from BasicNorm, add an eps instead.

Daniel Povey 2022-12-30 15:13:00 +08:00
parent da0623aa7f
commit 851912c581
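
For orientation: before this commit, BasicNorm subtracted a learned per-channel bias inside the mean-square and scaled the result by log_scale.exp(); after it, the bias is gone and a learned eps (stored as log_eps) is added to the mean-square before the inverse square root. A minimal sketch of the new computation in plain autograd, ignoring the memory-efficient custom Function below (names here are illustrative, not part of the commit):

import torch

def basic_norm_reference(x: torch.Tensor, eps: torch.Tensor, scale: torch.Tensor,
                         channel_dim: int = -1) -> torch.Tensor:
    # scales = (mean(x^2 over channels) + eps) ** -0.5, times a learned scale
    scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
    return x * scales

x = torch.randn(4, 100, 256)                                        # (batch, time, channels)
y = basic_norm_reference(x, torch.tensor(0.25), torch.tensor(1.0))  # eps=0.25 is the new default
print(y.shape)                                                      # torch.Size([4, 100, 256])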

@@ -434,46 +434,43 @@ class MaxEigLimiterFunction(torch.autograd.Function):
class BasicNormFunction(torch.autograd.Function):
# This computes:
# scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
# scales = ((torch.mean(x**2, keepdim=True) + eps) ** -0.5 * scale)
# return (x - bias) * scales
# (after unsqueezing the bias), but it does it in a memory-efficient way so that
# it can just store the returned value (chances are, this will also be needed for
# some other reason, related to the next operation, so we can save memory).
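# (Illustrative aside: because ans = x * scales and scales is saved in either
# case, backward() can recover x as ans / scales, which is why storing just one
# of x or ans is enough and why this saves memory.)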
@staticmethod
@custom_fwd
def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
store_output_for_backprop: bool) -> Tensor:
assert bias.ndim == 1
if channel_dim < 0:
channel_dim = channel_dim + x.ndim
ctx.store_output_for_backprop = store_output_for_backprop
ctx.channel_dim = channel_dim
for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1)
scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
ans = x * scales
ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
scales.detach(), bias.detach(), log_scale.detach())
scales.detach(), eps.detach(), scale.detach())
return ans
@staticmethod
@custom_bwd
def backward(ctx, ans_grad: Tensor) -> Tensor:
ans_or_x, scales, bias, log_scale = ctx.saved_tensors
ans_or_x, scales, eps, scale = ctx.saved_tensors
if ctx.store_output_for_backprop:
x = ans_or_x / scales
else:
x = ans_or_x
x = x.detach()
x.requires_grad = True
bias.requires_grad = True
log_scale.requires_grad = True
with torch.enable_grad():
# recompute scales from x, bias and log_scale.
scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
ans = x * scales
ans.backward(gradient=ans_grad)
return x.grad, bias.grad.flatten(), log_scale.grad, None, None
with torch.cuda.amp.autocast(enabled=False):
assert eps.dtype != torch.float16 and scale.dtype != torch.float16
x = x.to(torch.float32).detach()
x.requires_grad = True
eps.requires_grad = True
scale.requires_grad = True
with torch.enable_grad():
# recompute scales from x, eps and scale.
scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
ans = x * scales
ans.backward(gradient=ans_grad.to(torch.float32))
return x.grad.to(ans_grad.dtype), eps.grad, scale.grad, None, None
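
Since backward() recomputes the forward pass in float32 with autocast disabled, a quick consistency check is to compare its gradients against plain autograd on the same formula. A hedged sketch, assuming the BasicNormFunction defined above has been imported (CPU, float32):

import torch

x = torch.randn(3, 8, requires_grad=True)
eps = torch.tensor(0.25, requires_grad=True)
scale = torch.tensor(1.0, requires_grad=True)

# custom Function: channel_dim=-1, store_output_for_backprop=False
BasicNormFunction.apply(x, eps, scale, -1, False).sum().backward()

# the same formula through ordinary autograd
x2 = x.detach().clone().requires_grad_(True)
eps2 = eps.detach().clone().requires_grad_(True)
scale2 = scale.detach().clone().requires_grad_(True)
scales = ((torch.mean(x2 ** 2, dim=-1, keepdim=True) + eps2) ** -0.5) * scale2
(x2 * scales).sum().backward()

assert torch.allclose(x.grad, x2.grad, atol=1e-5)
assert torch.allclose(eps.grad, eps2.grad, atol=1e-5)
assert torch.allclose(scale.grad, scale2.grad, atol=1e-5)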
@@ -500,8 +497,11 @@ class BasicNorm(torch.nn.Module):
{-2, -1, 0, 1, 2, 3}.
log_scale: the initial log-scale that we multiply the output by; this
is learnable.
eps: the initial epsilon value (not in log space)
log_scale_min: FloatLike, minimum allowed value of log_scale
log_scale_max: FloatLike, maximum allowed value of log_scale
log_eps_min: FloatLike, minimum allowed value of log_eps
log_eps_max: FloatLike, maximum allowed value of log_eps
store_output_for_backprop: only possibly affects memory use; recommend
to set to True if you think the output of this module is more likely
than the input of this module to be required to be stored for the
@@ -512,18 +512,26 @@ class BasicNorm(torch.nn.Module):
self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
log_scale: float = 1.0,
log_scale_min: float = -1.5,
log_scale_max: float = 1.5,
log_scale: float = 2.0,
eps: float = 0.25,
log_scale_min: FloatLike = -1.5,
log_scale_max: FloatLike = 1.5,
log_eps_min: FloatLike = -3.0,
log_eps_max: FloatLike = 3.0,
store_output_for_backprop: bool = False
) -> None:
super(BasicNorm, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.log_scale = nn.Parameter(torch.tensor(log_scale))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.log_eps = nn.Parameter(torch.tensor(eps).log().detach())
self.log_scale_min = log_scale_min
self.log_scale_max = log_scale_max
self.log_eps_min = log_eps_min
self.log_eps_max = log_eps_max
self.store_output_for_backprop = store_output_for_backprop
def forward(self, x: Tensor) -> Tensor:
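
For a quick numeric orientation on the new defaults above (eps=0.25 stored as log_eps, clamped to [log_eps_min, log_eps_max] = [-3.0, 3.0]):

import math

print(math.log(0.25))                  # ~ -1.386: the initial log_eps for the default eps of 0.25
print(math.exp(-3.0), math.exp(3.0))   # ~ 0.0498 .. ~ 20.09: the range eps can cover under the log_eps limits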
@@ -537,7 +545,7 @@ class BasicNorm(torch.nn.Module):
bias = self.bias
for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1)
scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
scales = (((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + self.log_eps.exp()) ** -0.5) *
self.log_scale.exp())
return x * scales
@@ -545,8 +553,12 @@ class BasicNorm(torch.nn.Module):
min=float(self.log_scale_min),
max=float(self.log_scale_max),
training=self.training)
log_eps = limit_param_value(self.log_eps,
min=float(self.log_eps_min),
max=float(self.log_eps_max),
training=self.training)
return BasicNormFunction.apply(x, self.bias, log_scale,
return BasicNormFunction.apply(x, log_eps.exp(), log_scale.exp(),
self.channel_dim,
self.store_output_for_backprop)
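
Tying the last hunk together: log_eps and log_scale are each limited to their configured range and exponentiated before being handed to BasicNormFunction. A hedged sketch of the numeric effect, with a plain clamp standing in for the file's limit_param_value:

import torch

log_eps = torch.nn.Parameter(torch.tensor(0.25).log().detach())    # as in __init__ above
log_scale = torch.nn.Parameter(torch.tensor(2.0))                   # the new default log_scale

eps = log_eps.clamp(min=-3.0, max=3.0).exp()       # ~0.25: log(0.25) is well inside [-3, 3]
scale = log_scale.clamp(min=-1.5, max=1.5).exp()   # exp(1.5) ~ 4.48 under a hard clamp, since 2.0 > log_scale_max

# eps and scale (plus channel_dim and store_output_for_backprop) are the
# arguments BasicNormFunction.apply() receives in forward().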