Add BasicNorm module

Daniel Povey 2022-03-10 14:32:05 +08:00
parent feb20ca84d
commit 059b57ad37


@@ -336,6 +336,83 @@ class DerivBalancerFunction(torch.autograd.Function):
        return x_grad - neg_delta_grad, None, None, None, None
class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm.  The observation this is based on is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features.  Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    So the idea is to introduce this large constant value as an explicit
    parameter, that takes the role of the "eps" in LayerNorm, so the network
    doesn't have to do this trick.  We also introduce a learned scaling factor
    on the output; and we remove the subtracting-the-mean aspect of LayerNorm
    (which anyway, is not that useful unless the LayerNorm immediately follows
    a nonlinearity).

    Args:
      num_channels: the number of channels in the input.
      channel_dim: the axis/dimension corresponding to the channel,
        interpreted as an offset from the input's ndim if negative.
        This is NOT the num_channels; it should typically be one of
        {-2, -1, 0, 1, 2, 3}.
      initial_eps_scale: a constant that determines the initial
        "epsilon" that we add as ballast in:
            scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5
        Note: our epsilon is actually large, not small, but we keep the name
        to indicate the connection with normal LayerNorm.  We set
        epsilon initially to num_channels * initial_eps_scale.
      speed: a scaling factor that can be interpreted as scaling the learning
        rate for this module.  CAUTION: the default value of 10.0 is intended
        to be used with Adam or amsgrad-type optimizers, e.g. Noam.  If you
        are using SGD you would probably have to set `speed` to a value less
        than one, or the training would be unstable.
    """
    def __init__(self,
                 num_channels: int,
                 channel_dim: int = -1,  # CAUTION: see documentation.
                 initial_eps_scale: float = 0.25,
                 speed: float = 10.0):
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.speed = speed
        eps = num_channels * initial_eps_scale
        # log_eps = log(eps) / speed
        log_eps = torch.tensor(eps).log() / speed
        self.log_eps = nn.Parameter(log_eps.detach())
        # initial output-scale, to get LayerNorm-like behavior, is
        # sqrt(num_channels).
        initial_scale = torch.tensor(num_channels ** 0.5).log() / speed
        self.log_scale = nn.Parameter(initial_scale.detach())
    def _inner(self, x: Tensor) -> Tensor:
        # Inner product on the last dim of x, keeping the dimension,
        # i.e. torch.sum(x**2, dim=-1, keepdim=True).
        # Note: torch.inner is not usable here; for batched input it would
        # compute all pairwise inner products, not just the per-frame ones.
        # TODO: we could perhaps do this with matrix multiplication.
        return torch.sum(x ** 2, dim=-1, keepdim=True)
    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        x = x.transpose(-1, self.channel_dim)
        eps = (self.log_eps * self.speed).exp()
        out_scale = (self.log_scale * self.speed).exp()
        scales = out_scale * (self._inner(x) + eps) ** -0.5
        x = x * scales
        x = x.transpose(-1, self.channel_dim)
        return x
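
# Illustration (not part of this commit): for a channel-last input, the
# forward pass above reduces to the following standalone sketch.  The
# function name is made up for illustration; it assumes the file's existing
# `torch` / `Tensor` imports.
def _basic_norm_reference(x: Tensor,
                          log_eps: Tensor,
                          log_scale: Tensor,
                          speed: float = 10.0) -> Tensor:
    # The parameters are stored as log(value) / speed, so multiplying by
    # `speed` before exp() recovers the value while effectively scaling
    # the learning rate on these parameters.
    eps = (log_eps * speed).exp()
    out_scale = (log_scale * speed).exp()
    # RMS-style normalization: out_scale / sqrt(sum_i x_i^2 + eps),
    # with a large learned "eps" and no mean subtraction.
    scales = out_scale * (torch.sum(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
    return x * scales
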
class DerivBalancer(torch.nn.Module):
"""
@@ -367,16 +444,16 @@ class DerivBalancer(torch.nn.Module):
def _test_exp_scale_swish():
    class Swish(torch.nn.Module):
    class DoubleSwish(torch.nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            """Return Swish activation function."""
            return x * torch.sigmoid(x)
            return x * torch.sigmoid(x - 1.0)
    x1 = torch.randn(50, 60).detach()
    x2 = x1.detach()
    m1 = ExpScaleSwish(50, 1, speed=4.0)
    m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0))
    m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))
    x1.requires_grad = True
    x2.requires_grad = True
@@ -425,8 +502,26 @@ def _test_deriv_balancer():
    print("x grad = ", x.grad)
def _test_basic_norm():
    num_channels = 128
    m = BasicNorm(num_channels=num_channels, channel_dim=1)
    x = torch.randn(500, num_channels)
    y = m(x)
    assert y.shape == x.shape
    x_rms = (x**2).mean().sqrt()
    y_rms = (y**2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms
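
# Back-of-the-envelope check (not part of this commit) of why the rms asserts
# above should hold at initialization: for roughly unit-variance input,
# sum(x**2) over the channel dim is about num_channels, so the per-frame scale
# is close to the value computed by this illustrative helper.
def _expected_initial_rms_ratio(num_channels: int = 128,
                                initial_eps_scale: float = 0.25) -> float:
    eps = num_channels * initial_eps_scale
    # sqrt(num_channels) / sqrt(num_channels + eps); for the defaults this is
    # 1 / sqrt(1.25) ~= 0.894, i.e. between 0.5 * x_rms and x_rms.
    return (num_channels ** 0.5) * (num_channels + eps) ** -0.5
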
if __name__ == '__main__':
    _test_deriv_balancer()
    _test_exp_scale_swish()
    _test_exp_scale_relu()
    _test_basic_norm()