From fc873cc50d7e5a72344b0f081e93802acb441a73 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 17:00:17 +0800 Subject: [PATCH] Make epsilon in BasicNorm learnable, optionally. --- .../ASR/conformer_ctc/subsampling.py | 44 +++++++++++-------- .../ASR/transducer_stateless/conformer.py | 3 +- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 35de71e43..78fcac664 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -56,7 +56,10 @@ class Conv2dSubsampling(nn.Module): DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) - self.out_norm = BasicNorm(odim) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(odim, learn_eps=False) # constrain mean of output to be close to zero. self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6) self._reset_parameters() @@ -361,42 +364,45 @@ class BasicNorm(torch.nn.Module): So the idea is to introduce this large constant value as an explicit parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. - - We also introduce a learned scaling factor on the output; and we - remove the subtracting-the-mean aspect of LayerNorm (which anyway, is not - that useful unless the LayerNorm immediately follows a nonlinearity). - + doesn't have to do this trick. We make the "eps" learnable. Args: + num_channels: the number of channels, e.g. 512. channel_dim: the axis/dimension corresponding to the channel, interprted as an offset from the input's ndim if negative. shis is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. - initial_eps: the initial "epsilon" that we add as ballast in: + eps: the initial "epsilon" that we add as ballast in: scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with normal LayerNorm. - - speed: a scaling factor that can be interpreted as scaling the learning - rate for this module. CAUTION: the default value of 10.0 intended to be - used with Adam or amsgrad-type optimizers, e.g. Adam or Noam. - If you are using SGD you would probably have to set `speed` to - a value less than one, or the training would be unstable. + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_speed: a constant that determines how fast "eps" learns; + with Adam and variants, this should probably be >= 1, + e.g. 5.0. For SGD and variants, probably a value less than one, + like 0.1, would be suitable, to prevent instability. """ def __init__(self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25): + eps: float = 0.25, + learn_eps: bool = True, + eps_speed: float = 5.0): super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.eps = eps + self.eps_speed = eps_speed + if learn_eps: + self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) + else: + self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5 + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + + (self.eps * self.eps_speed).exp()) ** -0.5 return x * scales diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 54729652b..8b229a234 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -1129,4 +1129,5 @@ if __name__ == '__main__': seq_len = 20 # Just make sure the forward pass runs. f = c(torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64)) + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup_mode=True) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 488de3ccc..2af306f94 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved