diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 2df2678dd..622495f21 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -336,6 +336,83 @@ class DerivBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None
 
 
+class BasicNorm(torch.nn.Module):
+    """
+    This is intended to be a simpler, and hopefully cheaper, replacement for
+    LayerNorm. The observation this is based on is that Transformer-type
+    networks, especially with pre-norm, sometimes seem to set one of the
+    feature dimensions to a large constant value (e.g. 50), which "defeats"
+    the LayerNorm because the output magnitude is then not strongly dependent
+    on the other (useful) features. Presumably the weight and bias of the
+    LayerNorm are required to allow it to do this.
+
+    So the idea is to introduce this large constant value as an explicit
+    parameter, which takes the role of the "eps" in LayerNorm, so the network
+    doesn't have to do this trick.
+
+    We also introduce a learned scaling factor on the output; and we
+    remove the subtracting-the-mean aspect of LayerNorm (which anyway is not
+    that useful unless the LayerNorm immediately follows a nonlinearity).
+
+
+    Args:
+      channel_dim: the axis/dimension corresponding to the channel,
+        interpreted as an offset from the input's ndim if negative.
+        This is NOT the num_channels; it should typically be one of
+        {-2, -1, 0, 1, 2, 3}.
+      initial_eps_scale: a constant that determines the initial
+        "epsilon" that we add as ballast in:
+          scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5
+        Note: our epsilon is actually large, not small, but we keep the name
+        to indicate the connection with normal LayerNorm. We set
+        epsilon initially to num_channels * initial_eps_scale.
+      speed: a scaling factor that can be interpreted as scaling the learning
+        rate for this module. CAUTION: the default value of 10.0 is intended
+        to be used with Adam- or amsgrad-type optimizers, e.g. Adam or Noam.
+        If you are using SGD you would probably have to set `speed` to
+        a value less than one, or the training would be unstable.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 channel_dim: int = -1,  # CAUTION: see documentation.
+                 initial_eps_scale: float = 0.25,
+                 speed: float = 10.0):
+        super(BasicNorm, self).__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.speed = speed
+        eps = num_channels * initial_eps_scale
+        # log_eps = log(eps) / speed
+        log_eps = torch.tensor(eps).log() / speed
+        self.log_eps = nn.Parameter(log_eps.detach())
+        # initial output-scale, to get LayerNorm-like behavior, is
+        # sqrt(num_channels).
+        initial_scale = torch.tensor(num_channels ** 0.5).log() / speed
+        self.log_scale = nn.Parameter(initial_scale.detach())
+
+    def _inner(self, x: Tensor) -> Tensor:
+        # inner product on last dim of x, keeping the dimension,
+        # i.e. torch.sum(x**2, dim=-1, keepdim=True), but more
+        # efficient.
+        # Note: torch.inner(x, x) would give a result of shape
+        # x.shape[:-1] + x.shape[:-1] rather than the per-frame inner
+        # product, so we use a batched matrix multiplication instead.
+        return torch.matmul(x.unsqueeze(-2),
+                            x.unsqueeze(-1)).squeeze(-1)
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        assert x.shape[self.channel_dim] == self.num_channels
+        x = x.transpose(-1, self.channel_dim)
+        eps = (self.log_eps * self.speed).exp()
+        out_scale = (self.log_scale * self.speed).exp()
+
+        scales = out_scale * (self._inner(x) + eps) ** -0.5
+        x = x * scales
+        x = x.transpose(-1, self.channel_dim)
+        return x
+
+
 class DerivBalancer(torch.nn.Module):
     """
@@ -367,16 +444,16 @@ class DerivBalancer(torch.nn.Module):
 
 
 def _test_exp_scale_swish():
 
-    class Swish(torch.nn.Module):
+    class DoubleSwish(torch.nn.Module):
         def forward(self, x: Tensor) -> Tensor:
             """Return Swich activation function."""
-            return x * torch.sigmoid(x)
+            return x * torch.sigmoid(x - 1.0)
 
     x1 = torch.randn(50, 60).detach()
     x2 = x1.detach()
 
     m1 = ExpScaleSwish(50, 1, speed=4.0)
-    m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0))
+    m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0))
     x1.requires_grad = True
     x2.requires_grad = True
@@ -425,8 +502,26 @@ def _test_deriv_balancer():
     print("x grad = ", x.grad)
 
 
+def _test_basic_norm():
+    num_channels = 128
+    m = BasicNorm(num_channels=num_channels, channel_dim=1)
+
+    x = torch.randn(500, num_channels)
+
+    y = m(x)
+
+    assert y.shape == x.shape
+    x_rms = (x**2).mean().sqrt()
+    y_rms = (y**2).mean().sqrt()
+    print("x rms = ", x_rms)
+    print("y rms = ", y_rms)
+    assert y_rms < x_rms
+    assert y_rms > 0.5 * x_rms
+
+
 if __name__ == '__main__':
     _test_deriv_balancer()
     _test_exp_scale_swish()
     _test_exp_scale_relu()
+    _test_basic_norm()
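
Note (not part of the patch): at initialization, BasicNorm reduces to a fixed
RMS-style normalization, since eps = num_channels * initial_eps_scale and
output_scale = sqrt(num_channels) before any training. A minimal standalone
sketch checking the module against the docstring formula
scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5
(illustrative only; assumes the BasicNorm class above is importable):

    import torch

    num_channels = 128
    x = torch.randn(500, num_channels)

    # Values BasicNorm uses at init, per its docstring:
    # eps = 128 * 0.25 and output_scale = sqrt(128).
    eps = num_channels * 0.25
    out_scale = num_channels ** 0.5
    y_ref = x * out_scale * (torch.sum(x**2, dim=-1, keepdim=True) + eps) ** -0.5

    m = BasicNorm(num_channels=num_channels, channel_dim=-1)
    assert torch.allclose(m(x), y_ref, atol=1e-5)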
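
A side note on the renamed test helper (an inference from the "- 1.0" offset,
not something the patch states): DoubleSwish(x) = x * sigmoid(x - 1.0) closely
tracks swish(swish(x)), where swish(x) = x * sigmoid(x). A quick check of the
approximation:

    import torch

    def swish(x):
        return x * torch.sigmoid(x)

    x = torch.linspace(-5.0, 5.0, steps=101)
    double_swish = x * torch.sigmoid(x - 1.0)
    # Maximum gap between the two curves on [-5, 5]; roughly 0.06.
    print((double_swish - swish(swish(x))).abs().max())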