mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00
Add BasicNorm module
This commit is contained in:
parent
feb20ca84d
commit
059b57ad37
@@ -336,6 +336,83 @@ class DerivBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None
 
 
+class BasicNorm(torch.nn.Module):
+    """
+    This is intended to be a simpler, and hopefully cheaper, replacement for
+    LayerNorm.  The observation this is based on is that Transformer-type
+    networks, especially with pre-norm, sometimes seem to set one of the
+    feature dimensions to a large constant value (e.g. 50), which "defeats"
+    the LayerNorm because the output magnitude is then not strongly dependent
+    on the other (useful) features.  Presumably the weight and bias of the
+    LayerNorm are required to allow it to do this.
+
+    So the idea is to introduce this large constant value as an explicit
+    parameter that takes the role of the "eps" in LayerNorm, so the network
+    doesn't have to do this trick.
+
+    We also introduce a learned scaling factor on the output; and we
+    remove the subtracting-the-mean aspect of LayerNorm (which anyway is not
+    that useful unless the LayerNorm immediately follows a nonlinearity).
+
+    Args:
+      channel_dim: the axis/dimension corresponding to the channel,
+        interpreted as an offset from the input's ndim if negative.
+        This is NOT the num_channels; it should typically be one of
+        {-2, -1, 0, 1, 2, 3}.
+      initial_eps_scale: a constant that determines the initial
+        "epsilon" that we add as ballast in:
+            scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5
+        Note: our epsilon is actually large, not small, but we keep the name
+        to indicate the connection with normal LayerNorm.  We set
+        epsilon initially to num_channels * initial_eps_scale.
+      speed: a scaling factor that can be interpreted as scaling the learning
+        rate for this module.  CAUTION: the default value of 10.0 is intended
+        for use with Adam- or amsgrad-type optimizers, e.g. Adam or Noam.
+        If you are using SGD you would probably have to set `speed` to
+        a value less than one, or the training would be unstable.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 channel_dim: int = -1,  # CAUTION: see documentation.
+                 initial_eps_scale: float = 0.25,
+                 speed: float = 10.0):
+        super(BasicNorm, self).__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.speed = speed
+        eps = num_channels * initial_eps_scale
+        # log_eps = log(eps) / speed
+        log_eps = torch.tensor(eps).log() / speed
+        self.log_eps = nn.Parameter(log_eps.detach())
+        # The initial output scale, to get LayerNorm-like behavior, is
+        # sqrt(num_channels).
+        initial_scale = torch.tensor(num_channels ** 0.5).log() / speed
+        self.log_scale = nn.Parameter(initial_scale.detach())
+
+    def _inner(self, x: Tensor) -> Tensor:
+        # Inner product of each vector with itself on the last dim of x,
+        # keeping the dimension, i.e. torch.sum(x**2, dim=-1, keepdim=True).
+        # (Note: torch.inner is not usable here; it takes two tensors and
+        # computes pairwise inner products, not per-vector squared norms.)
+        # TODO: this could perhaps be done more efficiently, e.g. with a
+        # matrix multiplication.
+        return torch.sum(x**2, dim=-1, keepdim=True)
+
+    def forward(self, x: Tensor) -> Tensor:
+        assert x.shape[self.channel_dim] == self.num_channels
+        x = x.transpose(-1, self.channel_dim)
+        eps = (self.log_eps * self.speed).exp()
+        out_scale = (self.log_scale * self.speed).exp()
+
+        scales = out_scale * (self._inner(x) + eps) ** -0.5
+        x = x * scales
+        x = x.transpose(-1, self.channel_dim)
+        return x
+
+
 class DerivBalancer(torch.nn.Module):
     """
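For orientation, here is a minimal standalone sketch, not part of the commit, of what the new module computes at its initial parameter values (the helper name basic_norm_reference is made up for illustration; eps and out_scale correspond to num_channels * initial_eps_scale and sqrt(num_channels)):

import torch

def basic_norm_reference(x: torch.Tensor, eps: float, out_scale: float) -> torch.Tensor:
    # Scale each vector along the last dim by out_scale * (sum(x**2) + eps) ** -0.5,
    # i.e. the "scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5"
    # from the BasicNorm docstring, with fixed (non-learned) eps and out_scale.
    scales = out_scale * (torch.sum(x**2, dim=-1, keepdim=True) + eps) ** -0.5
    return x * scales

num_channels = 128
x = torch.randn(500, num_channels)
# Defaults from the commit: initial_eps_scale = 0.25, initial out_scale = sqrt(num_channels).
y = basic_norm_reference(x, eps=num_channels * 0.25, out_scale=num_channels ** 0.5)
print(y.shape, (y**2).mean().sqrt())  # same shape as x; rms slightly below the input rms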
@@ -367,16 +444,16 @@ class DerivBalancer(torch.nn.Module):
 
 
 def _test_exp_scale_swish():
-    class Swish(torch.nn.Module):
+    class DoubleSwish(torch.nn.Module):
         def forward(self, x: Tensor) -> Tensor:
             """Return Swish activation function."""
-            return x * torch.sigmoid(x)
+            return x * torch.sigmoid(x - 1.0)
 
     x1 = torch.randn(50, 60).detach()
     x2 = x1.detach()
 
     m1 = ExpScaleSwish(50, 1, speed=4.0)
-    m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0))
+    m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))
     x1.requires_grad = True
     x2.requires_grad = True
 
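The hunk above renames the local test helper from Swish to DoubleSwish and changes the activation from x * torch.sigmoid(x) to x * torch.sigmoid(x - 1.0). A quick standalone comparison of the two functions at a few points, not part of the commit (the names swish and double_swish are just for illustration):

import torch

def swish(x):
    # old test helper: x * sigmoid(x)
    return x * torch.sigmoid(x)

def double_swish(x):
    # new test helper: x * sigmoid(x - 1.0), i.e. the sigmoid shifted to the right
    return x * torch.sigmoid(x - 1.0)

x = torch.tensor([-2.0, 0.0, 2.0])
print(swish(x))         # tensor([-0.2384,  0.0000,  1.7616])
print(double_swish(x))  # tensor([-0.0949,  0.0000,  1.4621])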
@@ -425,8 +502,26 @@ def _test_deriv_balancer():
     print("x grad = ", x.grad)
 
 
+def _test_basic_norm():
+    num_channels = 128
+    m = BasicNorm(num_channels=num_channels, channel_dim=1)
+
+    x = torch.randn(500, num_channels)
+
+    y = m(x)
+
+    assert y.shape == x.shape
+    x_rms = (x**2).mean().sqrt()
+    y_rms = (y**2).mean().sqrt()
+    print("x rms = ", x_rms)
+    print("y rms = ", y_rms)
+    assert y_rms < x_rms
+    assert y_rms > 0.5 * x_rms
+
+
 if __name__ == '__main__':
     _test_deriv_balancer()
     _test_exp_scale_swish()
     _test_exp_scale_relu()
+    _test_basic_norm()
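As a rough sanity check on the bounds asserted in _test_basic_norm (an editorial aside, not part of the commit): at initialization out_scale = sqrt(num_channels) and eps = num_channels * 0.25, and for unit-variance input the per-row sum of squares is close to num_channels, so

    scale ~= sqrt(num_channels) / sqrt(num_channels + 0.25 * num_channels)
           = 1.25 ** -0.5
          ~= 0.894

which is why y_rms is expected to fall below x_rms but above 0.5 * x_rms.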