mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

Add BasicNorm module

parent feb20ca84d
commit 059b57ad37
@@ -336,6 +336,83 @@ class DerivBalancerFunction(torch.autograd.Function):
        return x_grad - neg_delta_grad, None, None, None, None


class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm.  The observation this is based on is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features.  Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    So the idea is to introduce this large constant value as an explicit
    parameter, which takes the role of the "eps" in LayerNorm, so the network
    doesn't have to do this trick.

    We also introduce a learned scaling factor on the output; and we
    remove the subtracting-the-mean aspect of LayerNorm (which is anyway not
    that useful unless the LayerNorm immediately follows a nonlinearity).

    Args:
      num_channels: the number of channels, i.e. the size of the input
         along channel_dim.
      channel_dim: the axis/dimension corresponding to the channel,
         interpreted as an offset from the input's ndim if negative.
         This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
      initial_eps_scale: a constant that determines the initial
         "epsilon" that we add as ballast in:
             scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5
         Note: our epsilon is actually large, not small, but we keep the name
         to indicate the connection with normal LayerNorm.  We set
         epsilon initially to num_channels * initial_eps_scale.
      speed: a scaling factor that can be interpreted as scaling the learning
         rate for this module.  CAUTION: the default value of 10.0 is intended
         to be used with Adam or amsgrad-type optimizers, e.g. Adam or Noam.
         If you are using SGD you would probably have to set `speed` to
         a value less than one, or the training would be unstable.
    """
    def __init__(self,
                 num_channels: int,
                 channel_dim: int = -1,  # CAUTION: see documentation.
                 initial_eps_scale: float = 0.25,
                 speed: float = 10.0):
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.speed = speed
        eps = num_channels * initial_eps_scale
        # log_eps = log(eps) / speed
        log_eps = torch.tensor(eps).log() / speed
        self.log_eps = nn.Parameter(log_eps.detach())
        # initial output-scale, to get LayerNorm-like behavior, is
        # sqrt(num_channels).
        initial_scale = torch.tensor(num_channels ** 0.5).log() / speed
        self.log_scale = nn.Parameter(initial_scale.detach())

    def _inner(self, x: Tensor) -> Tensor:
        # Sum of squares over the last dim of x, keeping the dimension,
        # i.e. torch.sum(x**2, dim=-1, keepdim=True).
        # (torch.inner computes pairwise inner products between rows, which
        # is not what we want here, so we use the explicit sum.)
        # TODO: we could do this with matrix multiplication, maybe.
        return torch.sum(x ** 2, dim=-1, keepdim=True)

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        x = x.transpose(-1, self.channel_dim)
        eps = (self.log_eps * self.speed).exp()
        out_scale = (self.log_scale * self.speed).exp()

        scales = out_scale * (self._inner(x) + eps) ** -0.5
        x = x * scales
        x = x.transpose(-1, self.channel_dim)
        return x


class DerivBalancer(torch.nn.Module):
    """
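For readers of the diff: a minimal sketch (not part of the commit) that re-derives what BasicNorm.forward computes for channel_dim=-1 directly from the stored log_eps and log_scale parameters, assuming the class above is in scope:

import torch

num_channels = 4
m = BasicNorm(num_channels=num_channels, channel_dim=-1)
x = torch.randn(3, num_channels)
y = m(x)

# forward() rescales each frame by out_scale / sqrt(sum(x**2) + eps),
# where eps and out_scale are stored in log space divided by `speed`.
eps = (m.log_eps * m.speed).exp()
out_scale = (m.log_scale * m.speed).exp()
expected = x * out_scale * (torch.sum(x ** 2, dim=-1, keepdim=True) + eps) ** -0.5
assert torch.allclose(y, expected)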
@@ -367,16 +444,16 @@ class DerivBalancer(torch.nn.Module):


def _test_exp_scale_swish():
-    class Swish(torch.nn.Module):
+    class DoubleSwish(torch.nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            """Return Swish activation function."""
-            return x * torch.sigmoid(x)
+            return x * torch.sigmoid(x - 1.0)

    x1 = torch.randn(50, 60).detach()
    x2 = x1.detach()

    m1 = ExpScaleSwish(50, 1, speed=4.0)
-    m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0))
+    m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))
    x1.requires_grad = True
    x2.requires_grad = True
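For reference (not part of the commit): the test's activation is renamed from Swish to DoubleSwish and now computes x * sigmoid(x - 1.0), i.e. Swish with the sigmoid's argument shifted by 1.0. A tiny standalone comparison of the two:

import torch

def swish(x: torch.Tensor) -> torch.Tensor:
    # plain Swish: x * sigmoid(x)
    return x * torch.sigmoid(x)

def double_swish(x: torch.Tensor) -> torch.Tensor:
    # matches DoubleSwish.forward in the test above
    return x * torch.sigmoid(x - 1.0)

x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
print("swish:       ", swish(x))
print("double_swish:", double_swish(x))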
@@ -425,8 +502,26 @@ def _test_deriv_balancer():
    print("x grad = ", x.grad)


def _test_basic_norm():
    num_channels = 128
    m = BasicNorm(num_channels=num_channels, channel_dim=1)

    x = torch.randn(500, num_channels)

    y = m(x)

    assert y.shape == x.shape
    x_rms = (x**2).mean().sqrt()
    y_rms = (y**2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms


if __name__ == '__main__':
    _test_deriv_balancer()
    _test_exp_scale_swish()
    _test_exp_scale_relu()
    _test_basic_norm()
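Beyond the unit test above, the intended use (per the class docstring) is as a drop-in replacement for LayerNorm in pre-norm Transformer-type blocks. A hedged sketch of such a swap, assuming BasicNorm from this diff is importable; d_model and the surrounding Sequential are purely illustrative and not part of the commit:

import torch
import torch.nn as nn

d_model = 256  # illustrative size

# A pre-norm feed-forward block with the usual LayerNorm ...
ffn_layernorm = nn.Sequential(
    nn.LayerNorm(d_model),
    nn.Linear(d_model, 4 * d_model),
    nn.ReLU(),
    nn.Linear(4 * d_model, d_model),
)

# ... and the same block with BasicNorm in its place.  channel_dim=-1
# because nn.Linear puts the channel/feature dimension last.
ffn_basicnorm = nn.Sequential(
    BasicNorm(d_model, channel_dim=-1),
    nn.Linear(d_model, 4 * d_model),
    nn.ReLU(),
    nn.Linear(4 * d_model, d_model),
)

x = torch.randn(10, 20, d_model)  # (batch, time, channel)
print(ffn_layernorm(x).shape, ffn_basicnorm(x).shape)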