diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index d617181db..9985c9001 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -85,6 +85,10 @@ class Zipformer(EncoderInterface):
         self.zipformer_downsampling_factors = zipformer_downsampling_factors
         self.output_downsampling_factor = output_downsampling_factor
 
+        # will be written to, see set_batch_count()
+        self.batch_count = 0
+        self.warmup_end = warmup_batches
+
         for u,d in zip(encoder_unmasked_dims, encoder_dims):
             assert u <= d
 
@@ -132,11 +136,53 @@ class Zipformer(EncoderInterface):
             encoders.append(encoder)
         self.encoders = nn.ModuleList(encoders)
+        # initializes self.skip_layers and self.skip_modules
+        self._init_skip_modules()
 
         self.downsample_output = AttentionDownsample(encoder_dims[-1],
                                                      encoder_dims[-1],
                                                      downsample=output_downsampling_factor)
+
+    def _get_layer_skip_dropout_prob(self):
+        if not self.training:
+            return 0.0
+        batch_count = self.batch_count
+        min_dropout_prob = 0.025
+
+        if batch_count > self.warmup_end:
+            return min_dropout_prob
+        else:
+            return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob)
+
+    def _init_skip_modules(self):
+        """
+        If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
+        indexed 4 (in zero indexing), which has downsampling_factor=4, we combine the outputs of
+        layers 2 and 3; and at the input of layer indexed 5, which has downsampling_factor=2,
+        we combine the outputs of layers 1 and 4.
+        """
+        skip_layers = []
+        skip_modules = []
+        z = self.zipformer_downsampling_factors
+        for i in range(len(z)):
+            if i <= 1 or z[i-1] <= z[i]:
+                skip_layers.append(None)
+                skip_modules.append(nn.Identity())
+            else:
+                # TEMP
+                for j in range(i-2, -1, -1):
+                    if z[j] <= z[i] or j == 0:
+                        # TEMP logging statement.
+                        logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
+                                     f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.")
+                        skip_layers.append(j)
+                        skip_modules.append(SimpleCombiner(self.encoder_dims[j],
+                                                           self.encoder_dims[i-1]))
+                        break
+        self.skip_layers = skip_layers
+        self.skip_modules = nn.ModuleList(skip_modules)
+
     def get_feature_masks(
             self, x: torch.Tensor) -> List[Union[float, Tensor]]:
@@ -221,20 +267,25 @@
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
 
+        outputs = []
         feature_masks = self.get_feature_masks(x)
 
         for i, module in enumerate(self.encoders):
             ds = self.zipformer_downsampling_factors[i]
+            if self.skip_layers[i] is not None:
+                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
+                if (not self.training) or random.random() > layer_skip_dropout_prob:
+                    x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
             x = module(x,
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
+            outputs.append(x)
 
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
         lengths = (lengths + 1) // 2
-        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
         return x, lengths
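
Not part of the patch: a minimal standalone sketch of the skip-connection "dropout" schedule implemented by _get_layer_skip_dropout_prob(). The probability decays linearly from 0.5 at batch 0 to a floor of 0.025 at self.warmup_end and is held there afterwards; in forward() the skip connection is applied only when a uniform random draw exceeds this probability (and always at inference). The warmup_end value below is illustrative only, not taken from the patch.

def layer_skip_dropout_prob(batch_count: int, warmup_end: float,
                            min_dropout_prob: float = 0.025) -> float:
    # Linear decay from 0.5 (batch 0) to min_dropout_prob (batch == warmup_end),
    # then held constant; mirrors the training-mode branch of the patched method.
    if batch_count > warmup_end:
        return min_dropout_prob
    return 0.5 - (batch_count / warmup_end) * (0.5 - min_dropout_prob)

# warmup_end=2000 is an arbitrary illustrative value.
for b in (0, 500, 1000, 2000, 5000):
    print(b, layer_skip_dropout_prob(b, warmup_end=2000))
# prints roughly: 0 -> 0.5, 500 -> 0.381, 1000 -> 0.263, 2000 -> 0.025, 5000 -> 0.025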
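
Also not part of the patch: the pairing rule described in the _init_skip_modules() docstring can be checked with a small standalone sketch. SimpleCombiner and the encoder dims are omitted; only the choice of which earlier stack's output is combined with the previous stack's output is reproduced.

def skip_pairs(z):
    # For each encoder stack i, return (j, i - 1) if the input of stack i combines
    # the outputs of stacks j and i - 1, else None; same search as _init_skip_modules().
    pairs = []
    for i in range(len(z)):
        if i <= 1 or z[i - 1] <= z[i]:
            pairs.append(None)
        else:
            # walk backwards to the nearest earlier stack whose downsampling factor
            # is <= that of stack i (falling back to stack 0)
            for j in range(i - 2, -1, -1):
                if z[j] <= z[i] or j == 0:
                    pairs.append((j, i - 1))
                    break
    return pairs

print(skip_pairs((1, 2, 4, 8, 4, 2)))
# [None, None, None, None, (2, 3), (1, 4)]
# i.e. stack 4 combines the outputs of stacks 2 and 3, and stack 5 combines
# the outputs of stacks 1 and 4, matching the docstring and the logging output.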