Revert BasicNorm to its previous status, without the bias

Daniel Povey 2022-12-22 23:47:21 +08:00
parent b2125535fb
commit ade7db54e3
2 changed files with 7 additions and 67 deletions


@@ -430,51 +430,6 @@ class MaxEigLimiterFunction(torch.autograd.Function):
         return x_grad + x_extra_grad.detach(), None, None, None, None
-class BasicNormFunction(torch.autograd.Function):
-    # This computes:
-    #    scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
-    #    return x * scales
-    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
-    # it can just store the returned value (chances are, this will also be needed for
-    # some other reason, related to the next operation, so we can save memory).
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int,
-                store_output_for_backprop: bool) -> Tensor:
-        assert bias.ndim == 1
-        if channel_dim < 0:
-            channel_dim = channel_dim + x.ndim
-        ctx.store_output_for_backprop = store_output_for_backprop
-        ctx.channel_dim = channel_dim
-        for _ in range(channel_dim + 1, x.ndim):
-            bias = bias.unsqueeze(-1)
-        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
-        ans = x * scales
-        ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
-                              scales.detach(), bias.detach(), eps.detach())
-        return ans
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, ans_grad: Tensor) -> Tensor:
-        ans_or_x, scales, bias, eps = ctx.saved_tensors
-        if ctx.store_output_for_backprop:
-            x = ans_or_x / scales
-        else:
-            x = ans_or_x
-        x = x.detach()
-        x.requires_grad = True
-        bias.requires_grad = True
-        eps.requires_grad = True
-        with torch.enable_grad():
-            # recompute scales from x, bias and eps.
-            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
-            ans = x * scales
-            ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), eps.grad, None, None
 class BasicNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
@@ -491,7 +446,7 @@ class BasicNorm(torch.nn.Module):
     Args:
        num_channels: the number of channels, e.g. 512.
        channel_dim: the axis/dimension corresponding to the channel,
          interpreted as an offset from the input's ndim if negative.
          This is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
@@ -501,13 +456,10 @@ class BasicNorm(torch.nn.Module):
          to indicate the connection with conventional LayerNorm.
        learn_eps: if true, we learn epsilon; if false, we keep it
          at the initial value.
-       store_output_for_backprop: this option makes no difference
-         to the output, but may affect memory usage; determines
-         whether, for backprop purposes, we store the input or the output
-         of this module.
        eps_min: float
        eps_max: float
     """
     def __init__(
         self,
         num_channels: int,
@@ -516,7 +468,6 @@ class BasicNorm(torch.nn.Module):
         learn_eps: bool = True,
         eps_min: float = -3.0,
         eps_max: float = 3.0,
-        store_output_for_backprop: bool = True
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
@@ -525,24 +476,12 @@ class BasicNorm(torch.nn.Module):
             self.eps = nn.Parameter(torch.tensor(eps).log().detach())
         else:
             self.register_buffer("eps", torch.tensor(eps).log().detach())
-        self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps_min = eps_min
         self.eps_max = eps_max
-        self.store_output_for_backprop = store_output_for_backprop

     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         eps = self.eps
-        if torch.jit.is_scripting():
-            bias = self.bias
-            channel_dim = self.channel_dim
-            if channel_dim < 0:
-                channel_dim = channel_dim + x.ndim
-            for _ in range(channel_dim + 1, x.ndim):
-                bias = bias.unsqueeze(-1)
-            scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
-            return x * scales
         if self.training and random.random() < 0.25:
             # with probability 0.25, in training mode, clamp eps between the min
             # and max; this will encourage it to learn parameters within the
@@ -552,9 +491,10 @@ class BasicNorm(torch.nn.Module):
             # gradients to allow the parameter to get back into the allowed
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)
-        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim,
-                                       self.store_output_for_backprop)
+        scales = (
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+        ) ** -0.5
+        return x * scales

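The deleted BasicNormFunction existed purely for memory efficiency: rather than keeping the input alive for backprop, it could store only the output, reconstruct the input as output / scales, and redo the forward under torch.enable_grad() to get exact gradients. A minimal standalone illustration of that trick (a sketch with hypothetical names, simplified to drop the bias and normalize over the last dimension; not the project's code):

import torch
from torch import Tensor

class RecomputeNormSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, eps: Tensor) -> Tensor:
        scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps.exp()) ** -0.5
        ans = x * scales
        # Save the output rather than the input: x is recoverable as
        # ans / scales, so the input tensor need not be kept alive.
        ctx.save_for_backward(ans.detach(), scales.detach(), eps.detach())
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ans, scales, eps = ctx.saved_tensors
        x = (ans / scales).detach()
        x.requires_grad = True
        eps.requires_grad = True
        with torch.enable_grad():
            # Redo the forward so autograd can produce exact gradients
            # w.r.t. the reconstructed x and eps.
            scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps.exp()) ** -0.5
            (x * scales).backward(gradient=ans_grad)
        return x.grad, eps.grad

x = torch.randn(4, 8, requires_grad=True)
eps = torch.tensor(0.25).log().detach().requires_grad_()
RecomputeNormSketch.apply(x, eps).sum().backward()  # populates x.grad and eps.grad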

@@ -451,7 +451,7 @@ class ZipformerEncoderLayer(nn.Module):
         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
-        self.norm_final = BasicNorm(embed_dim, store_output_for_backprop=False)
+        self.norm_final = BasicNorm(embed_dim)
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
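
After this revert, BasicNorm reduces to an RMS-style normalization with a learned log-epsilon and no bias. The sketch below (hypothetical names, not the project's code; the default eps of 0.25 is an assumption, and the custom-autograd machinery is omitted) shows the resulting computation, including the occasional eps clamp from the diff above:

import random
import torch
import torch.nn as nn
from torch import Tensor

class BasicNormSketch(nn.Module):
    def __init__(self, num_channels: int, channel_dim: int = -1,
                 eps: float = 0.25,  # assumed default, not shown in the diff
                 eps_min: float = -3.0, eps_max: float = 3.0) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        # eps is learned in log space, matching the diff above.
        self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        eps = self.eps
        if self.training and random.random() < 0.25:
            # Occasionally clamp eps to [eps_min, eps_max] so it learns to
            # stay in range, while usually leaving it unclamped so gradients
            # can pull it back if it drifts out.
            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
        ) ** -0.5
        return x * scales

x = torch.randn(2, 100, 512)   # (batch, time, channels)
y = BasicNormSketch(512)(x)    # same shape as x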