Re-introduce bias into BasicNorm and replace eps with log_scale.

This commit is contained in:
Daniel Povey 2022-12-26 21:22:00 +08:00
parent 920ed685ac
commit 71d7843654
2 changed files with 102 additions and 41 deletions

View File

@ -430,6 +430,54 @@ class MaxEigLimiterFunction(torch.autograd.Function):
return x_grad + x_extra_grad.detach(), None, None, None, None return x_grad + x_extra_grad.detach(), None, None, None, None
class BasicNormFunction(torch.autograd.Function):
    # This computes:
    #   scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
    #   return x * scales
    # i.e. the bias is subtracted only when computing the normalization scale,
    # NOT from the returned output (after unsqueezing the bias so it broadcasts
    # against x).  It does this in a memory-efficient way so that it can just
    # store the returned value (chances are, this will also be needed for
    # some other reason, related to the next operation, so we can save memory).
    @staticmethod
    @custom_fwd
    def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
                store_output_for_backprop: bool) -> Tensor:
        """
        Args:
           x: the input tensor, of arbitrary shape.
           bias: a 1-D tensor of shape (num_channels,); subtracted from x only
              for the purpose of computing the normalization scale.
           log_scale: a scalar tensor; the output is multiplied by log_scale.exp().
           channel_dim: the dimension of x corresponding to the channels
              (may be negative, interpreted as an offset from x.ndim).
           store_output_for_backprop: if True, save the output rather than the
              input for backprop; this can save memory if the output would have
              to be stored anyway for the next operation.
        Returns:
           x * scales, where scales has x's shape with channel_dim reduced to 1.
        """
        assert bias.ndim == 1
        if channel_dim < 0:
            channel_dim = channel_dim + x.ndim
        ctx.store_output_for_backprop = store_output_for_backprop
        ctx.channel_dim = channel_dim
        # unsqueeze bias until it broadcasts against x along channel_dim.
        for _ in range(channel_dim + 1, x.ndim):
            bias = bias.unsqueeze(-1)
        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
        ans = x * scales
        ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
                              scales.detach(), bias.detach(), log_scale.detach())
        return ans

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad: Tensor):
        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
        if ctx.store_output_for_backprop:
            # we stored the output; recover the input via ans == x * scales.
            x = ans_or_x / scales
        else:
            x = ans_or_x
        x = x.detach()
        x.requires_grad = True
        bias.requires_grad = True
        log_scale.requires_grad = True
        with torch.enable_grad():
            # recompute scales from x, bias and log_scale, and let autograd
            # produce all three gradients for us.
            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
            ans = x * scales
            ans.backward(gradient=ans_grad)
        # bias was unsqueezed in forward(), so flatten its grad back to 1-D.
        return x.grad, bias.grad.flatten(), log_scale.grad, None, None
class BasicNorm(torch.nn.Module): class BasicNorm(torch.nn.Module):
""" """
This is intended to be a simpler, and hopefully cheaper, replacement for This is intended to be a simpler, and hopefully cheaper, replacement for
@ -450,47 +498,57 @@ class BasicNorm(torch.nn.Module):
interpreted as an offset from the input's ndim if negative. interpreted as an offset from the input's ndim if negative.
this is NOT the num_channels; it should typically be one of this is NOT the num_channels; it should typically be one of
{-2, -1, 0, 1, 2, 3}. {-2, -1, 0, 1, 2, 3}.
eps: the initial "epsilon" that we add as ballast in: log_scale: the initial log-scale that we multiply the output by; this
scale = ((input_vec**2).mean() + epsilon)**-0.5 is learnable.
Note: our epsilon is actually large, but we keep the name log_scale_min: FloatLike, minimum allowed value of log_scale
to indicate the connection with conventional LayerNorm. log_scale_max: FloatLike, maximum allowed value of log_scale
learn_eps: if true, we learn epsilon; if false, we keep it store_output_for_backprop: only possibly affects memory use; recommend
at the initial value. to set to True if you think the output of this module is more likely
eps_min: float than the input of this module to be required to be stored for the
eps_max: float backprop.
""" """
def __init__( def __init__(
self, self,
num_channels: int, num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation. channel_dim: int = -1, # CAUTION: see documentation.
eps: float = 0.25, log_scale: float = 1.0,
learn_eps: bool = True, log_scale_min: float = -1.5,
eps_min: float = -3.0, log_scale_max: float = 1.5,
eps_max: float = 3.0, store_output_for_backprop: bool = False
) -> None: ) -> None:
super(BasicNorm, self).__init__() super(BasicNorm, self).__init__()
self.num_channels = num_channels self.num_channels = num_channels
self.channel_dim = channel_dim self.channel_dim = channel_dim
if learn_eps: self.log_scale = nn.Parameter(torch.tensor(log_scale))
self.eps = nn.Parameter(torch.tensor(eps).log().detach()) self.bias = nn.Parameter(torch.zeros(num_channels))
else: self.log_scale_min = log_scale_min
self.register_buffer("eps", torch.tensor(eps).log().detach()) self.log_scale_max = log_scale_max
self.eps_min = eps_min self.store_output_for_backprop = store_output_for_backprop
self.eps_max = eps_max
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels assert x.shape[self.channel_dim] == self.num_channels
eps = self.eps
if self.training: if torch.jit.is_scripting():
eps = limit_param_value(self.eps, min=self.eps_min, max=self.eps_max) channel_dim = self.channel_dim
eps = eps.exp() if channel_dim < 0:
scales = ( channel_dim += x.ndim
(torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps) bias = self.bias
# / (1.0 + eps) for _ in range(channel_dim + 1, x.ndim):
) ** -0.5 bias = bias.unsqueeze(-1)
return x * scales scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
self.log_scale.exp())
return x * scales
log_scale = limit_param_value(self.log_scale,
min=float(self.log_scale_min),
max=float(self.log_scale_max),
training=self.training)
return BasicNormFunction.apply(x, self.bias, log_scale,
self.channel_dim,
self.store_output_for_backprop)
@ -516,7 +574,8 @@ class PositiveConv1d(nn.Conv1d):
(N, C, H) (N, C, H)
i.e. (batch_size, num_channels, height) i.e. (batch_size, num_channels, height)
""" """
weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max)) weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
training=self.training)
# make absolutely sure there are no negative values. For parameter-averaging-related # make absolutely sure there are no negative values. For parameter-averaging-related
# reasons, we prefer to also use limit_param_value to make sure the weights stay # reasons, we prefer to also use limit_param_value to make sure the weights stay
# positive. # positive.
@ -634,7 +693,8 @@ class PositiveConv2d(nn.Conv2d):
(N, C, H, W) (N, C, H, W)
i.e. (batch_size, num_channels, height, width) i.e. (batch_size, num_channels, height, width)
""" """
weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max)) weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
training=self.training)
# make absolutely sure there are no negative values. For parameter-averaging-related # make absolutely sure there are no negative values. For parameter-averaging-related
# reasons, we prefer to also use limit_param_value to make sure the weights stay # reasons, we prefer to also use limit_param_value to make sure the weights stay
# positive. # positive.
@ -1156,13 +1216,14 @@ class LimitParamValue(torch.autograd.Function):
def limit_param_value(x: Tensor, def limit_param_value(x: Tensor,
min: float, max: float, min: float, max: float,
prob: float = 0.6): prob: float = 0.6,
training: bool = True):
# You apply this to (typically) an nn.Parameter during training to ensure that its # You apply this to (typically) an nn.Parameter during training to ensure that its
# (elements mostly) stays within a supplied range. This is done by modifying the # (elements mostly) stays within a supplied range. This is done by modifying the
# gradients in backprop. # gradients in backprop.
# It's not necessary to do this on every batch: do it only some of the time, # It's not necessary to do this on every batch: do it only some of the time,
# to save a little time. # to save a little time.
if random.random() < prob: if training and random.random() < prob:
return LimitParamValue.apply(x, min, max) return LimitParamValue.apply(x, min, max)
else: else:
return x return x

View File

@ -453,7 +453,7 @@ class ZipformerEncoderLayer(nn.Module):
self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2) self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
self.norm_final = BasicNorm(embed_dim, eps_max=4.0) self.norm_final = BasicNorm(embed_dim)
self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
@ -868,11 +868,10 @@ class SimpleCombiner(torch.nn.Module):
dim2 = src2.shape[-1] dim2 = src2.shape[-1]
weight1 = self.weight1 weight1 = limit_param_value(self.weight1,
if self.training: min=self.min_weight[0],
weight1 = limit_param_value(weight1, max=1.0-self.min_weight[1],
min=self.min_weight[0], training=self.training)
max=1.0-self.min_weight[1])
src1_dim = src1.shape[-1] src1_dim = src1.shape[-1]
src2_dim = src2.shape[-1] src2_dim = src2.shape[-1]
@ -1896,7 +1895,8 @@ class Conv2dSubsampling(nn.Module):
x = x * limit_param_value(self.scale, x = x * limit_param_value(self.scale,
min=float(self.scale_min), min=float(self.scale_min),
max=float(self.scale_max)) max=float(self.scale_max),
training=self.training)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out(x) x = self.out(x)