Remove bias from BasicNorm, add an eps instead.

Daniel Povey 2022-12-30 15:13:00 +08:00
parent da0623aa7f
commit 851912c581

@@ -434,46 +434,43 @@ class MaxEigLimiterFunction(torch.autograd.Function):
 class BasicNormFunction(torch.autograd.Function):
     # This computes:
-    #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
+    #   scales = ((torch.mean(x**2, keepdim=True) + eps) ** -0.5 * scale)
     #   return (x - bias) * scales
     # (after unsqueezing the bias), but it does it in a memory-efficient way so that
     # it can just store the returned value (chances are, this will also be needed for
     # some other reason, related to the next operation, so we can save memory).
     @staticmethod
-    @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
+    def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
                 store_output_for_backprop: bool) -> Tensor:
-        assert bias.ndim == 1
         if channel_dim < 0:
             channel_dim = channel_dim + x.ndim
         ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
-        for _ in range(channel_dim + 1, x.ndim):
-            bias = bias.unsqueeze(-1)
-        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
+        scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
         ans = x * scales
         ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
-                              scales.detach(), bias.detach(), log_scale.detach())
+                              scales.detach(), eps.detach(), scale.detach())
         return ans
 
     @staticmethod
-    @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tensor:
-        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
+        ans_or_x, scales, eps, scale = ctx.saved_tensors
         if ctx.store_output_for_backprop:
             x = ans_or_x / scales
         else:
             x = ans_or_x
-        x = x.detach()
-        x.requires_grad = True
-        bias.requires_grad = True
-        log_scale.requires_grad = True
-        with torch.enable_grad():
-            # recompute scales from x, bias and log_scale.
-            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
-            ans = x * scales
-            ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), log_scale.grad, None, None
+        with torch.cuda.amp.autocast(enabled=False):
+            assert eps.dtype != torch.float16 and scale.dtype != torch.float16
+            x = x.to(torch.float32).detach()
+            x.requires_grad = True
+            eps.requires_grad = True
+            scale.requires_grad = True
+            with torch.enable_grad():
+                # recompute scales from x, eps and scale.
+                scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
+                ans = x * scales
+                ans.backward(gradient=ans_grad.to(torch.float32))
+        return x.grad.to(ans_grad.dtype), eps.grad, scale.grad, None, None
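
For reference (not part of the commit), the computation the rewritten BasicNormFunction implements can be written directly in a few lines of plain PyTorch. The sketch below is only a way to check the memory-efficient custom autograd against ordinary autograd; the shapes, dtypes, and the commented call to BasicNormFunction are illustrative assumptions.

    import torch

    def basic_norm_ref(x: torch.Tensor, eps: torch.Tensor, scale: torch.Tensor,
                       channel_dim: int = -1) -> torch.Tensor:
        # Straightforward (memory-hungry) version of the new formula:
        # RMS-normalize over the channel dimension, with an additive eps inside
        # the mean and a learned multiplicative scale, both in linear space.
        scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
        return x * scales

    x = torch.randn(2, 5, 8, dtype=torch.double)
    eps = torch.tensor(0.25, dtype=torch.double)
    scale = torch.tensor(1.0, dtype=torch.double)
    y_ref = basic_norm_ref(x, eps, scale)
    # If scaling.py is importable, the custom function should agree:
    # y = BasicNormFunction.apply(x, eps, scale, -1, False)
    # torch.testing.assert_close(y, y_ref)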
@@ -500,8 +497,11 @@ class BasicNorm(torch.nn.Module):
          {-2, -1, 0, 1, 2, 3}.
        log_scale: the initial log-scale that we multiply the output by; this
          is learnable.
+       eps: the initial epsilon value (not in log space)
        log_scale_min: FloatLike, minimum allowed value of log_scale
        log_scale_max: FloatLike, maximum allowed value of log_scale
+       log_eps_min: FloatLike, minimum allowed value of log_eps
+       log_eps_max: FloatLike, maximum allowed value of log_eps
        store_output_for_backprop: only possibly affects memory use; recommend
          to set to True if you think the output of this module is more likely
          than the input of this module to be required to be stored for the
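
A quick numeric check (not from the commit, using the default limits introduced in __init__ below): storing eps in log space with those limits still allows a fairly wide effective range.

    import math

    # The default eps = 0.25 is stored as log_eps = log(0.25) ≈ -1.386, which sits
    # well inside the default limits [log_eps_min, log_eps_max] = [-3.0, 3.0].
    print(math.log(0.25))                   # -1.3862943611198906
    # Those limits correspond to an effective eps between roughly 0.05 and 20.1:
    print(math.exp(-3.0), math.exp(3.0))    # 0.049787068367863944 20.085536923187668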
@@ -512,18 +512,26 @@ class BasicNorm(torch.nn.Module):
         self,
         num_channels: int,
         channel_dim: int = -1,  # CAUTION: see documentation.
-        log_scale: float = 1.0,
-        log_scale_min: float = -1.5,
-        log_scale_max: float = 1.5,
+        log_scale: float = 2.0,
+        eps: float = 0.25,
+        log_scale_min: FloatLike = -1.5,
+        log_scale_max: FloatLike = 1.5,
+        log_eps_min: FloatLike = -3.0,
+        log_eps_max: FloatLike = 3.0,
         store_output_for_backprop: bool = False
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.log_scale = nn.Parameter(torch.tensor(log_scale))
-        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.log_eps = nn.Parameter(torch.tensor(eps).log().detach())
         self.log_scale_min = log_scale_min
         self.log_scale_max = log_scale_max
+        self.log_eps_min = log_eps_min
+        self.log_eps_max = log_eps_max
         self.store_output_for_backprop = store_output_for_backprop
 
     def forward(self, x: Tensor) -> Tensor:
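
As a usage sketch (not part of the commit), assuming the edited scaling.py is importable under that name; the channel count and input shape below are arbitrary examples:

    import torch
    from scaling import BasicNorm  # the file this commit edits; import path assumed

    norm = BasicNorm(num_channels=384, channel_dim=-1)  # 384 is just an example width
    x = torch.randn(4, 100, 384)   # (batch, time, channels)
    y = norm(x)                    # normalized output, same shape as x

    # After this change the learnable state is log_scale plus log_eps; there is
    # no longer a per-channel bias parameter.
    print(sorted(name for name, _ in norm.named_parameters()))  # ['log_eps', 'log_scale']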
@@ -537,7 +545,7 @@ class BasicNorm(torch.nn.Module):
             bias = self.bias
             for _ in range(channel_dim + 1, x.ndim):
                 bias = bias.unsqueeze(-1)
-            scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
+            scales = (((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + self.eps.exp()) ** -0.5) *
                       self.log_scale.exp())
             return x * scales
@ -545,8 +553,12 @@ class BasicNorm(torch.nn.Module):
min=float(self.log_scale_min), min=float(self.log_scale_min),
max=float(self.log_scale_max), max=float(self.log_scale_max),
training=self.training) training=self.training)
log_eps = limit_param_value(self.log_eps,
min=float(self.log_eps_min),
max=float(self.log_eps_max),
training=self.training)
return BasicNormFunction.apply(x, self.bias, log_scale, return BasicNormFunction.apply(x, log_eps.exp(), log_scale.exp(),
self.channel_dim, self.channel_dim,
self.store_output_for_backprop) self.store_output_for_backprop)
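
The forward path above limits the log-space parameters and only then exponentiates them, so BasicNormFunction receives eps and scale in linear space. Below is a minimal sketch of that flow, using torch.clamp as a stand-in for limit_param_value (whose exact behavior is defined elsewhere in scaling.py and is not shown in this diff):

    import torch

    log_eps = torch.tensor(0.25).log()    # the default initial value, ≈ -1.386
    log_scale = torch.tensor(2.0)         # the default initial log_scale

    # Stand-in for limit_param_value: keep the parameters inside their allowed ranges.
    log_eps = torch.clamp(log_eps, min=-3.0, max=3.0)      # unchanged, already in range
    log_scale = torch.clamp(log_scale, min=-1.5, max=1.5)  # under a plain clamp, 2.0 -> 1.5

    # What gets handed to BasicNormFunction.apply(...) in linear space:
    eps = log_eps.exp()       # ≈ 0.25
    scale = log_scale.exp()   # ≈ 4.48 (= e^1.5) under this stand-in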