Remove eps from BasicNorm and reintroduce bias

This commit is contained in:
Daniel Povey 2023-01-02 00:02:31 +08:00
parent a2227a07fc
commit 3a5b3f640d
2 changed files with 35 additions and 42 deletions

View File

@ -434,42 +434,44 @@ class MaxEigLimiterFunction(torch.autograd.Function):
class BasicNormFunction(torch.autograd.Function):
# This computes:
# scales = ((torch.mean(x**2, keepdim=True) + eps) ** -0.5 * scale)
# scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
# return (x - bias) * scales
# (after unsqueezing the bias), but it does it in a memory-efficient way so that
# it can just store the returned value (chances are, this will also be needed for
# some other reason, related to the next operation, so we can save memory).
@staticmethod
def forward(ctx, x: Tensor, eps: Tensor, scale: Tensor, channel_dim: int,
def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
store_output_for_backprop: bool) -> Tensor:
assert bias.ndim == 1
if channel_dim < 0:
channel_dim = channel_dim + x.ndim
ctx.store_output_for_backprop = store_output_for_backprop
ctx.channel_dim = channel_dim
scales = ((torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5) * scale
for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1)
scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
ans = x * scales
ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
scales.detach(), eps.detach(), scale.detach())
scales.detach(), bias.detach(), log_scale.detach())
return ans
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tensor:
ans_or_x, scales, eps, scale = ctx.saved_tensors
ans_or_x, scales, bias, log_scale = ctx.saved_tensors
if ctx.store_output_for_backprop:
x = ans_or_x / scales
else:
x = ans_or_x
with torch.cuda.amp.autocast(enabled=False):
assert eps.dtype != torch.float16 and scale.dtype != torch.float16
x = x.to(torch.float32).detach()
x.requires_grad = True
eps.requires_grad = True
scale.requires_grad = True
with torch.enable_grad():
# recompute scales from x, eps and log_scale.
scales = ((torch.mean(x ** 2, dim=ctx.channel_dim, keepdim=True) + eps) ** -0.5) * scale
ans = x * scales
ans.backward(gradient=ans_grad.to(torch.float32))
return x.grad.to(ans_grad.dtype), eps.grad, scale.grad, None, None
x = x.detach()
x.requires_grad = True
bias.requires_grad = True
log_scale.requires_grad = True
with torch.enable_grad():
# recompute scales from x, bias and log_scale.
scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
ans = x * scales
ans.backward(gradient=ans_grad)
return x.grad, bias.grad.flatten(), log_scale.grad, None, None
@ -489,17 +491,14 @@ class BasicNorm(torch.nn.Module):
Args:
num_channels: the number of channels, e.g. 512.
channel_dim: the axis/dimension corresponding to the channel,
interprted as an offset from the input's ndim if negative.
shis is NOT the num_channels; it should typically be one of
{-2, -1, 0, 1, 2, 3}.
channel_dim: the axis/dimension corresponding to the channel,
interprted as an offset from the input's ndim if negative.
shis is NOT the num_channels; it should typically be one of
{-2, -1, 0, 1, 2, 3}.
log_scale: the initial log-scale that we multiply the output by; this
is learnable.
eps: the initial epsilon value (not in log space)
log_scale_min: FloatLike, minimum allowed value of log_scale
log_scale_max: FloatLike, maximum allowed value of log_scale
log_eps_min: FloatLike, minimum allowed value of log_eps
log_eps_max: FloatLike, maximum allowed value of log_eps
store_output_for_backprop: only possibly affects memory use; recommend
to set to True if you think the output of this module is more likely
than the input of this module to be required to be stored for the
@ -511,25 +510,19 @@ class BasicNorm(torch.nn.Module):
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
log_scale: float = 1.0,
eps: float = 0.1,
log_scale_min: FloatLike = -1.5,
log_scale_max: FloatLike = 1.5,
log_eps_min: FloatLike = -3.0,
log_eps_max: FloatLike = 3.0,
log_scale_min: float = -1.5,
log_scale_max: float = 1.5,
store_output_for_backprop: bool = False
) -> None:
super(BasicNorm, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.log_scale = nn.Parameter(torch.tensor(log_scale))
self.log_eps = nn.Parameter(torch.tensor(eps).log().detach())
self.bias = nn.Parameter(torch.zeros(num_channels))
self.log_scale_min = log_scale_min
self.log_scale_max = log_scale_max
self.log_eps_min = log_eps_min
self.log_eps_max = log_eps_max
self.store_output_for_backprop = store_output_for_backprop
def forward(self, x: Tensor) -> Tensor:
@ -538,7 +531,12 @@ class BasicNorm(torch.nn.Module):
if torch.jit.is_scripting():
channel_dim = self.channel_dim
scales = (((torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.log_eps.exp()) ** -0.5) *
if channel_dim < 0:
channel_dim += x.ndim
bias = self.bias
for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1)
scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
self.log_scale.exp())
return x * scales
@ -546,12 +544,8 @@ class BasicNorm(torch.nn.Module):
min=float(self.log_scale_min),
max=float(self.log_scale_max),
training=self.training)
log_eps = limit_param_value(self.log_eps,
min=float(self.log_eps_min),
max=float(self.log_eps_max),
training=self.training)
return BasicNormFunction.apply(x, log_eps.exp(), log_scale.exp(),
return BasicNormFunction.apply(x, self.bias, log_scale,
self.channel_dim,
self.store_output_for_backprop)

View File

@ -1885,8 +1885,7 @@ class Conv2dSubsampling(nn.Module):
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
# getting large, there is an unnecessary degree of freedom.
self.out_norm = BasicNorm(out_channels, eps=1.0,
log_eps_min=-0.1, log_eps_max=0.0)
self.out_norm = BasicNorm(out_channels)
self.dropout = Dropout2(dropout)