Re-introduce bias into BasicNorm and replace eps with log_scale.

Daniel Povey 2022-12-26 21:22:00 +08:00
parent 920ed685ac
commit 71d7843654
2 changed files with 102 additions and 41 deletions
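In rough terms, the old BasicNorm normalized by the root-mean-square of x with a learned "epsilon" ballast term, while the new one subtracts a learnable per-channel bias before computing that statistic and multiplies the output by exp(log_scale). A minimal, non-memory-efficient sketch of the before/after behaviour, assuming channel_dim == -1 (the function names here are illustrative, not part of the module's API):

    import torch

    def old_basic_norm(x: torch.Tensor, eps: float = 0.25) -> torch.Tensor:
        # pre-commit behaviour: an "epsilon" ballast term inside the mean-square
        scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
        return x * scales

    def new_basic_norm(x: torch.Tensor, bias: torch.Tensor, log_scale: torch.Tensor) -> torch.Tensor:
        # post-commit behaviour: the bias only affects the statistic used for scaling;
        # the output is still x itself times the scale.
        scales = (torch.mean((x - bias) ** 2, dim=-1, keepdim=True) ** -0.5) * log_scale.exp()
        return x * scales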


@@ -430,6 +430,54 @@ class MaxEigLimiterFunction(torch.autograd.Function):
        return x_grad + x_extra_grad.detach(), None, None, None, None


class BasicNormFunction(torch.autograd.Function):
    # This computes:
    #   scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True)) ** -0.5 * log_scale.exp()
    #   return x * scales
    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
    # it can just store the returned value (chances are, this will also be needed for
    # some other reason, related to the next operation, so we can save memory).
    @staticmethod
    @custom_fwd
    def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int,
                store_output_for_backprop: bool) -> Tensor:
        assert bias.ndim == 1
        if channel_dim < 0:
            channel_dim = channel_dim + x.ndim
        ctx.store_output_for_backprop = store_output_for_backprop
        ctx.channel_dim = channel_dim
        for _ in range(channel_dim + 1, x.ndim):
            bias = bias.unsqueeze(-1)
        scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
        ans = x * scales
        ctx.save_for_backward(ans.detach() if store_output_for_backprop else x,
                              scales.detach(), bias.detach(), log_scale.detach())
        return ans

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad: Tensor) -> Tensor:
        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
        if ctx.store_output_for_backprop:
            x = ans_or_x / scales
        else:
            x = ans_or_x
        x = x.detach()
        x.requires_grad = True
        bias.requires_grad = True
        log_scale.requires_grad = True
        with torch.enable_grad():
            # recompute scales from x, bias and log_scale.
            scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp()
            ans = x * scales
            ans.backward(gradient=ans_grad)
        return x.grad, bias.grad.flatten(), log_scale.grad, None, None
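Because the backward pass recomputes scales (and, when store_output_for_backprop is set, reconstructs x from the stored output), it is worth checking that the custom function agrees with plain autograd. A small consistency-check sketch; the BasicNormFunction class defined above is passed in explicitly, since its import path depends on your checkout:

    import torch

    def check_basic_norm_function(BasicNormFunction) -> None:
        # Pass in the BasicNormFunction class defined above.
        num_channels = 8
        x = torch.randn(3, num_channels, 5, dtype=torch.double, requires_grad=True)
        bias = torch.randn(num_channels, dtype=torch.double, requires_grad=True)
        log_scale = torch.tensor(0.3, dtype=torch.double, requires_grad=True)

        # memory-efficient custom function: channel_dim=1, store_output_for_backprop=True
        y1 = BasicNormFunction.apply(x, bias, log_scale, 1, True)

        # straightforward reference computation with ordinary autograd
        b = bias.unsqueeze(-1)  # broadcast over the trailing dim
        scales = (torch.mean((x - b) ** 2, dim=1, keepdim=True) ** -0.5) * log_scale.exp()
        y2 = x * scales

        g = torch.randn_like(y1)
        gx1, gb1, gl1 = torch.autograd.grad(y1, (x, bias, log_scale), grad_outputs=g)
        gx2, gb2, gl2 = torch.autograd.grad(y2, (x, bias, log_scale), grad_outputs=g)
        assert torch.allclose(y1, y2)
        assert torch.allclose(gx1, gx2) and torch.allclose(gb1, gb2) and torch.allclose(gl1, gl2)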

class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
@@ -450,47 +498,57 @@ class BasicNorm(torch.nn.Module):
         interpreted as an offset from the input's ndim if negative.
         This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
      eps: the initial "epsilon" that we add as ballast in:
             scale = ((input_vec**2).mean() + epsilon)**-0.5
           Note: our epsilon is actually large, but we keep the name
           to indicate the connection with conventional LayerNorm.
      learn_eps: if true, we learn epsilon; if false, we keep it
           at the initial value.
      eps_min: float
      eps_max: float
      log_scale: the initial log-scale that we multiply the output by; this
           is learnable.
      log_scale_min: FloatLike, minimum allowed value of log_scale
      log_scale_max: FloatLike, maximum allowed value of log_scale
      store_output_for_backprop: only possibly affects memory use; recommend
           to set to True if you think the output of this module is more likely
           than the input of this module to be required to be stored for the
           backprop.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        log_scale: float = 1.0,
        log_scale_min: float = -1.5,
        log_scale_max: float = 1.5,
        store_output_for_backprop: bool = False
    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        if learn_eps:
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        else:
            self.register_buffer("eps", torch.tensor(eps).log().detach())
        self.eps_min = eps_min
        self.eps_max = eps_max
        self.log_scale = nn.Parameter(torch.tensor(log_scale))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.log_scale_min = log_scale_min
        self.log_scale_max = log_scale_max
        self.store_output_for_backprop = store_output_for_backprop

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        eps = self.eps
        if self.training:
            eps = limit_param_value(self.eps, min=self.eps_min, max=self.eps_max)
        eps = eps.exp()
        scales = (
            (torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps)
            # / (1.0 + eps)
        ) ** -0.5
        return x * scales
        if torch.jit.is_scripting():
            channel_dim = self.channel_dim
            if channel_dim < 0:
                channel_dim += x.ndim
            bias = self.bias
            for _ in range(channel_dim + 1, x.ndim):
                bias = bias.unsqueeze(-1)
            scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) *
                      self.log_scale.exp())
            return x * scales
        log_scale = limit_param_value(self.log_scale,
                                      min=float(self.log_scale_min),
                                      max=float(self.log_scale_max),
                                      training=self.training)
        return BasicNormFunction.apply(x, self.bias, log_scale,
                                       self.channel_dim,
                                       self.store_output_for_backprop)
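A short usage sketch of the rewritten module, assuming it is imported from this scaling.py (the import path depends on your checkout); the tensor layout follows the default channel_dim of -1:

    import torch
    from scaling import BasicNorm  # assumed import path; adjust to your checkout

    norm = BasicNorm(num_channels=384)   # channel_dim defaults to -1 (last dim)
    x = torch.randn(10, 4, 384)          # e.g. (seq_len, batch_size, embed_dim)
    y = norm(x)                          # same shape as x
    assert y.shape == x.shape
    # The learnable state is now 'log_scale' (a scalar) and 'bias' (per-channel); 'eps' is gone.
    print([name for name, _ in norm.named_parameters()])  # ['log_scale', 'bias']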
@@ -516,7 +574,8 @@ class PositiveConv1d(nn.Conv1d):
           (N, C, H)
           i.e. (batch_size, num_channels, height)
        """
        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
                                   training=self.training)
        # make absolutely sure there are no negative values.  For parameter-averaging-related
        # reasons, we prefer to also use limit_param_value to make sure the weights stay
        # positive.
@@ -634,7 +693,8 @@ class PositiveConv2d(nn.Conv2d):
           (N, C, H, W)
           i.e. (batch_size, num_channels, height, width)
        """
        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max))
        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
                                   training=self.training)
        # make absolutely sure there are no negative values.  For parameter-averaging-related
        # reasons, we prefer to also use limit_param_value to make sure the weights stay
        # positive.
@@ -1156,13 +1216,14 @@ class LimitParamValue(torch.autograd.Function):
def limit_param_value(x: Tensor,
                      min: float, max: float,
                      prob: float = 0.6):
                      prob: float = 0.6,
                      training: bool = True):
    # You apply this to (typically) an nn.Parameter during training to ensure that it
    # (or rather, most of its elements) stays within a supplied range.  This is done
    # by modifying the gradients in backprop.
    # It's not necessary to do this on every batch: do it only some of the time,
    # to save a little time.
    if random.random() < prob:
    if training and random.random() < prob:
        return LimitParamValue.apply(x, min, max)
    else:
        return x
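The new training argument makes the randomized gradient limiting a no-op in eval mode, so call sites no longer need to wrap it in `if self.training:`; the follow-up changes below (ZipformerEncoderLayer, SimpleCombiner, Conv2dSubsampling) just pass training=self.training through. A hedged sketch of the calling pattern (the module and parameter here are illustrative, not part of this commit):

    import torch
    import torch.nn as nn
    from scaling import limit_param_value  # assumed import path; adjust to your checkout

    class ScaledResidual(nn.Module):  # illustrative module, not part of this commit
        def __init__(self, num_channels: int = 256):
            super().__init__()
            self.scale = nn.Parameter(torch.ones(num_channels))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # In eval mode (or, during training, on the batches where the random
            # check fails) the parameter is returned unchanged; otherwise
            # LimitParamValue modifies its gradients so the values are pushed
            # back towards [0.1, 10.0].
            scale = limit_param_value(self.scale, min=0.1, max=10.0,
                                      training=self.training)
            return x * scale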


@@ -453,7 +453,7 @@ class ZipformerEncoderLayer(nn.Module):
        self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
        self.norm_final = BasicNorm(embed_dim, eps_max=4.0)
        self.norm_final = BasicNorm(embed_dim)
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
@@ -868,11 +868,10 @@ class SimpleCombiner(torch.nn.Module):
        dim2 = src2.shape[-1]
        weight1 = self.weight1
        if self.training:
            weight1 = limit_param_value(weight1,
                                        min=self.min_weight[0],
                                        max=1.0-self.min_weight[1])
        weight1 = limit_param_value(self.weight1,
                                    min=self.min_weight[0],
                                    max=1.0-self.min_weight[1],
                                    training=self.training)
        src1_dim = src1.shape[-1]
        src2_dim = src2.shape[-1]
@@ -1896,7 +1895,8 @@ class Conv2dSubsampling(nn.Module):
        x = x * limit_param_value(self.scale,
                                  min=float(self.scale_min),
                                  max=float(self.scale_max))
                                  max=float(self.scale_max),
                                  training=self.training)
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
        x = self.out(x)