From e9c69d847723764890c3b6a5d0047db75a13d7d9 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 30 Oct 2022 14:41:18 +0800
Subject: [PATCH] Add warmup schedule for zipformer encoder layer, from 1.0 -> 0.2.

---
 .../pruned_transducer_stateless7/zipformer.py | 38 +++++++++++++------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c7cc4440a..2a4de2c48 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -272,6 +272,9 @@ class ZipformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
+        # will be written to, see set_batch_count()
+        self.batch_count = 0
+
         self.self_attn = RelPositionMultiheadAttention(
             d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )
@@ -310,6 +313,25 @@ class ZipformerEncoderLayer(nn.Module):
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
+    def get_bypass_scale(self):
+        if torch.jit.is_scripting() or not self.training:
+            return self.bypass_scale
+        if random.random() < 0.1:
+            # ensure we get grads if self.bypass_scale becomes out of range
+            return self.bypass_scale
+        # hardcode warmup period for bypass scale
+        warmup_period = 4000.0
+        initial_clamp_min = 1.0
+        final_clamp_min = 0.2
+        if self.batch_count > warmup_period:
+            clamp_min = final_clamp_min
+        else:
+            clamp_min = (initial_clamp_min -
+                         (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min))
+        return self.bypass_scale.clamp(min=clamp_min, max=1.0)
+
+
+
     def forward(
         self,
         src: Tensor,
@@ -363,14 +385,8 @@ class ZipformerEncoderLayer(nn.Module):
         src = self.norm_final(self.balancer(src))
 
         delta = src - src_orig
-        bypass_scale = self.bypass_scale
-        if torch.jit.is_scripting() or (not self.training) or random.random() > 0.1:
-            # with probability 0.9, in training mode, or always, in testing
-            # mode, clamp bypass_scale to [ 0.1, 1.0 ]; this will encourage it
-            # to learn parameters within this range by making parameters that
-            # are outside that range range noisy.
-            bypass_scale = bypass_scale.clamp(min=0.5, max=1.0)
-        src = src_orig + delta * bypass_scale
+
+        src = src_orig + delta * self.get_bypass_scale()
 
         return self.whiten(src)
 
@@ -397,9 +413,9 @@ class ZipformerEncoder(nn.Module):
         warmup_end: float
     ) -> None:
         super().__init__()
-        # self.batch_count will be written to by the top-level training program.
-        # Note: in inference time this may be zero but should be treated as large,
-        # we can check if self.training is true.
+        # will be written to, see set_batch_count(). Note: at inference time
+        # this may be zero but should be treated as large; we can check if
+        # self.training is true.
         self.batch_count = 0
         self.warmup_begin = warmup_begin
         self.warmup_end = warmup_end
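
Notes (not part of the patch): the clamp floor that get_bypass_scale() applies
to bypass_scale decays linearly from 1.0 to 0.2 over the first 4000 batches and
then stays at 0.2. A minimal standalone sketch of that schedule follows, with
the constants copied from the hunk above; the helper name bypass_clamp_min is
hypothetical and exists only for illustration:

    # Sketch of the warmup schedule in get_bypass_scale() above.
    def bypass_clamp_min(batch_count: float) -> float:
        # constants copied from the patch
        warmup_period = 4000.0
        initial_clamp_min = 1.0
        final_clamp_min = 0.2
        if batch_count > warmup_period:
            return final_clamp_min
        # linear interpolation from 1.0 down to 0.2 over the warmup period
        return initial_clamp_min - (batch_count / warmup_period) * (
            initial_clamp_min - final_clamp_min)

    for n in (0, 1000, 2000, 4000, 8000):
        print(n, bypass_clamp_min(n))  # ~1.0, 0.8, 0.6, 0.2, 0.2

During training the schedule is driven by self.batch_count, which the training
loop is expected to keep current via set_batch_count() (per the comments in the
patch). The 10% of batches on which the scale is returned unclamped are what
let bypass_scale keep receiving gradients if it drifts outside
[clamp_min, 1.0], since clamp() passes no gradient to out-of-range values.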