diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 95eed36e4..19c52be38 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -84,6 +84,15 @@ class Zipformer(EncoderInterface):
         self.zipformer_downsampling_factors = zipformer_downsampling_factors
         self.output_downsampling_factor = output_downsampling_factor
 
+        # keep track of how many times forward() has been called, for purposes
+        # of warmup.  do this with a floating-point count because integer counts
+        # can fail to survive model averaging.  initialize with a smallish
+        # random number so that different encoders use different random seeds in
+        # shared_rng get_layers_to_drop() while using the same random seeds
+        # across jobs.
+        self.register_buffer('warmup_count', torch.tensor(float(10.0 * random.random())))
+        self.warmup_end = warmup_batches
+
         for u,d in zip(encoder_unmasked_dims, encoder_dims):
             assert u <= d
 
@@ -138,6 +147,35 @@ class Zipformer(EncoderInterface):
             encoder_dims[-1], downsample=output_downsampling_factor)
 
+
+    def get_warmup_count(self) -> float:
+        """
+        Returns a value that reflects how many times this function has been called in training mode.
+        """
+        ans = self.warmup_count.item()
+        if self.training:
+            if ans > 1000000.0:
+                # this ensures that as the number of batches gets large, the warmup count cycles rather
+                # than getting stuck at the smallest floating point value x such that x + 1 == x.
+                # this is necessary because get_layers_to_drop() relies on the warmup count changing
+                # on every batch.
+                next_count = 500000.0
+            else:
+                next_count = ans + 1.0
+            self.warmup_count.fill_(next_count)
+        return ans
+
+    def _get_layer_skip_dropout_prob(self):
+        if not self.training:
+            return 0.0
+        warmup_count = self.get_warmup_count()
+        min_dropout_prob = 0.025
+
+        if warmup_count > self.warmup_end:
+            return min_dropout_prob
+        else:
+            return 0.5 - (warmup_count / self.warmup_end) * (0.5 - min_dropout_prob)
+
     def _init_skip_modules(self):
         """
         If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
@@ -256,7 +294,7 @@ class Zipformer(EncoderInterface):
         for i, module in enumerate(self.encoders):
             ds = self.zipformer_downsampling_factors[i]
             if self.skip_layers[i] is not None:
-                layer_skip_dropout_prob = 0.05
+                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
                 if (not self.training) or random.random() > layer_skip_dropout_prob:
                     x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
             x = module(x,
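
For context, a minimal standalone sketch (not part of the patch) of the schedule that the new _get_layer_skip_dropout_prob() implements in training mode: the layer-skip dropout probability decays linearly from 0.5 at warmup count 0 down to a floor of 0.025 once the count passes self.warmup_end, and stays at the floor afterwards. The warmup_end value of 2000 below is purely illustrative; in the patch it comes from the warmup_batches constructor argument.

def layer_skip_dropout_prob(warmup_count: float, warmup_end: float) -> float:
    # same formula as _get_layer_skip_dropout_prob() uses during training
    min_dropout_prob = 0.025
    if warmup_count > warmup_end:
        return min_dropout_prob
    return 0.5 - (warmup_count / warmup_end) * (0.5 - min_dropout_prob)

if __name__ == "__main__":
    warmup_end = 2000.0  # hypothetical value; the patch takes this from warmup_batches
    for count in (0.0, 500.0, 1000.0, 2000.0, 5000.0):
        print(f"warmup_count={count:7.1f}  p_skip={layer_skip_dropout_prob(count, warmup_end):.4f}")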