From bf37c7ca85602d89fa9229f9ece5f4533e1a88a6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 26 Oct 2022 12:51:20 +0800
Subject: [PATCH] Regularize how we apply the min and max to the eps of BasicNorm

---
 .../pruned_transducer_stateless7/conformer.py   |  6 +++++-
 .../ASR/pruned_transducer_stateless7/optim.py   |  2 +-
 .../ASR/pruned_transducer_stateless7/scaling.py | 16 ++++++++++++++++
 3 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index fb9549a00..8fc20915b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -360,7 +360,11 @@ class ConformerEncoderLayer(nn.Module):
         delta = src - src_orig

         bypass_scale = self.bypass_scale
-        if random.random() > 0.1:
+        if self.training and random.random() < 0.25:
+            # With probability 0.25, in training mode, clamp bypass_scale to
+            # [0.1, 1.0]; this will encourage it to learn parameters within
+            # this range by making parameters that are outside that range
+            # noisy.
             bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 25295bf16..bb8b0a0e3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -159,7 +159,7 @@ class ScaledAdam(BatchedOptimizer):
         eps=1.0e-08,
         param_min_rms=1.0e-05,
         param_max_rms=3.0,
-        scalar_max=5.0,
+        scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
     ):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index a2200c04b..72dfaf446 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -348,6 +348,8 @@ class BasicNorm(torch.nn.Module):
          to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it at the
          initial value.
+      eps_min: float
+      eps_max: float
    """

    def __init__(
@@ -356,6 +358,8 @@ class BasicNorm(torch.nn.Module):
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
+       eps_min: float = -3.0,
+       eps_max: float = 3.0,
    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
@@ -364,9 +368,21 @@ class BasicNorm(torch.nn.Module):
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        else:
            self.register_buffer("eps", torch.tensor(eps).log().detach())
+       self.eps_min = eps_min
+       self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
+       eps = self.eps
+       if self.training and random.random() < 0.25:
+           # With probability 0.25, in training mode, clamp eps between the
+           # min and max; this will encourage it to learn parameters within
+           # the allowed range by making parameters that are outside the
+           # allowed range noisy.
+
+           # Since the clamp is only applied part of the time, the parameter
+           # still gets gradients to move back into the allowed region.
+           eps = eps.clamp(min=self.eps_min, max=self.eps_max)
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.eps.exp()
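
For readers who want the trick in isolation: below is a minimal, self-contained PyTorch sketch (not code from icefall) of the regularization pattern this patch applies to bypass_scale and to BasicNorm's eps. With some probability during training, a learnable scalar is clamped into an allowed range before it is used, so out-of-range values make the output "noisy" while the parameter still receives ordinary gradients on the unclamped steps. The class name ToyBasicNorm, the clamp_prob argument, and the defaults below are illustrative assumptions, not part of the patch.

# Sketch only: a BasicNorm-style layer whose learnable log-epsilon is
# occasionally clamped to [eps_min, eps_max] during training.
import random

import torch
import torch.nn as nn
from torch import Tensor


class ToyBasicNorm(nn.Module):
    def __init__(
        self,
        num_channels: int,
        eps: float = 0.25,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
        clamp_prob: float = 0.25,
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        # Store log(eps) so that eps.exp() is always positive.
        self.eps = nn.Parameter(torch.tensor(eps).log())
        self.eps_min = eps_min
        self.eps_max = eps_max
        self.clamp_prob = clamp_prob

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[-1] == self.num_channels
        eps = self.eps
        if self.training and random.random() < self.clamp_prob:
            # Occasionally clamp log-eps into the allowed range; the clamped
            # value is what enters the computation below.  On the remaining
            # steps the unclamped parameter is used and gets normal gradients.
            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
        scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + eps.exp()) ** -0.5
        return x * scales


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = ToyBasicNorm(num_channels=8)
    y = layer(torch.randn(2, 8))
    print(y.shape)  # torch.Size([2, 8])

In the patch itself the same idea is written inline in ConformerEncoderLayer.forward and BasicNorm.forward rather than in a separate module.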