diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index ca93068a4..855d1c4d8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -129,14 +129,6 @@ class Zipformer2(EncoderInterface):
         dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
 
-        # this is not the probability of skipping a layer. It is the probability of
-        # dropping out the "skip module" which allows the model to skip groups of
-        # encoder stacks; when it's dropped out like this, it means we are forced
-        # to take the "direct" (non-bypass) path.
-        self.layer_skip_dropout_prob = ScheduledFloat((0.0, 0.5),
-                                                      (warmup_batches, 0.025),
-                                                      (20000.0, 0.0))
-
         def _to_tuple(x):
             """ Converts a single int or a 1-tuple of an int to a tuple
             with the same length as downsampling_factor"""
@@ -223,41 +215,10 @@ class Zipformer2(EncoderInterface):
 
         self.encoders = nn.ModuleList(encoders)
 
-        # initializes self.skip_layers and self.skip_modules
-        self._init_skip_modules()
-
         self.downsample_output = SimpleDownsample(max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
-
-    def _init_skip_modules(self):
-        """
-        If self.downampling_factor = (1, 2, 4, 8, 4, 2), then at the input of layer
-        indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of
-        layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2,
-        we combine the outputs of layers 1 and 5.
-        """
-        skip_layers = []
-        skip_modules = []
-        z = self.downsampling_factor
-        for i in range(len(z)):
-            if i <= 1 or z[i-1] <= z[i]:
-                skip_layers.append(None)
-                skip_modules.append(Identity())
-            else:
-                # TEMP
-                for j in range(i-2, -1, -1):
-                    if z[j] <= z[i] or j == 0:
-                        # TEMP logging statement.
-                        logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
-                                     f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
-                        skip_layers.append(j)
-                        skip_modules.append(BypassModule(self.encoder_dim[i]))
-                        break
-        self.skip_layers = skip_layers
-        self.skip_modules = nn.ModuleList(skip_modules)
-
     def get_feature_masks(
             self,
             x: torch.Tensor) -> List[Union[float, Tensor]]:
@@ -384,23 +345,6 @@ class Zipformer2(EncoderInterface):
             ds = self.downsampling_factor[i]
             x = convert_num_channels(x, self.encoder_dim[i])
 
-            if self.skip_layers[i] is not None:
-                # this how we implement U-net-like skipping of some series of
-                # stacks.  The layer_skip_dropout_prob is to discourage it from
-                # completely ignoring the middle layers, especially early in
-                # training,
-                skip_output = convert_num_channels(outputs[self.skip_layers[i]],
-                                                   self.encoder_dim[i])
-                skip_x = self.skip_modules[i](skip_output, x)
-
-                layer_skip_dropout_prob = float(self.layer_skip_dropout_prob)
-                if self.training and layer_skip_dropout_prob > 0:
-                    batch_size = x.shape[1]
-                    mask = (torch.rand((1, batch_size, 1), device=x.device) >
-                            layer_skip_dropout_prob)
-                    x = torch.where(mask, skip_x, x)
-                else:
-                    x = skip_x
             x = module(x,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
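
For context on the first hunk: `layer_skip_dropout_prob` was a `ScheduledFloat`, icefall's piecewise-linear schedule over the training batch count (defined in scaling.py), so the skip-dropout probability started at 0.5, decayed to 0.025 by `warmup_batches`, and reached 0.0 at batch 20000. A minimal sketch of that interpolation, assuming plain (batch_count, value) breakpoints; the helper name and the example warmup value of 4000 are hypothetical:

```python
from typing import List, Tuple

def scheduled_value(batch_count: float,
                    points: List[Tuple[float, float]]) -> float:
    """Piecewise-linear interpolation over sorted (batch_count, value)
    breakpoints, clamped at both ends -- a sketch of what a ScheduledFloat
    evaluates to at a given batch count."""
    x0, y0 = points[0]
    if batch_count <= x0:
        return y0
    for x1, y1 in points[1:]:
        if batch_count <= x1:
            return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)
        x0, y0 = x1, y1
    return y0  # past the last breakpoint

# Mirroring the removed schedule, with warmup_batches assumed to be 4000:
points = [(0.0, 0.5), (4000.0, 0.025), (20000.0, 0.0)]
print(scheduled_value(0.0, points))      # 0.5 at the start of training
print(scheduled_value(2000.0, points))   # ~0.2625 halfway through warmup
print(scheduled_value(30000.0, points))  # 0.0 after batch 20000
```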
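The second hunk deletes `_init_skip_modules`, which paired each encoder stack with an earlier one: whenever the downsampling factor decreases from stack i-1 to stack i, the output of the closest earlier stack j whose factor is at most that of stack i gets combined with stack i-1's output through a `BypassModule`. A standalone sketch of just the pairing rule, with a hypothetical helper name and no icefall dependencies:

```python
from typing import List, Optional

def compute_skip_pairs(downsampling_factor: List[int]) -> List[Optional[int]]:
    """For each encoder stack i, return the index j of the earlier stack
    whose output is combined with stack i-1's output at the input of
    stack i, or None where the direct path is taken.  Mirrors the loop in
    the removed _init_skip_modules."""
    z = downsampling_factor
    pairs: List[Optional[int]] = []
    for i in range(len(z)):
        if i <= 1 or z[i - 1] <= z[i]:
            # no skip while the factor is flat or still increasing
            pairs.append(None)
        else:
            # factor just decreased: search backwards for the closest
            # stack operating at this resolution or finer (or stack 0)
            for j in range(i - 2, -1, -1):
                if z[j] <= z[i] or j == 0:
                    pairs.append(j)
                    break
    return pairs

# With the factors from the removed docstring:
print(compute_skip_pairs([1, 2, 4, 8, 4, 2]))
# -> [None, None, None, None, 2, 1]
# i.e. stack 4 combines the outputs of stacks 2 and 3, and stack 5
# combines the outputs of stacks 1 and 4 (the docstring's "layers 1
# and 5" appears to be a typo for this).
```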
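The third hunk removes the forward-pass side of the mechanism: the `BypassModule` output `skip_x` replaces the direct path `x`, except that during training each sequence in the batch independently falls back to the direct path with probability `layer_skip_dropout_prob`, so the middle stacks cannot be ignored entirely early in training. A self-contained sketch of that masking, assuming (seq_len, batch, dim) tensors as in the original; the function name is hypothetical:

```python
import torch

def apply_skip_dropout(skip_x: torch.Tensor,
                       x: torch.Tensor,
                       skip_dropout_prob: float,
                       training: bool) -> torch.Tensor:
    """skip_x: bypass-combined output; x: direct-path output of the
    previous stack; both (seq_len, batch, dim).  At inference time, or
    once the scheduled probability reaches zero, the skip path is always
    taken."""
    if training and skip_dropout_prob > 0:
        batch_size = x.shape[1]
        # True where the skip path is kept; broadcasts over seq_len and dim
        keep_skip = (torch.rand((1, batch_size, 1), device=x.device)
                     > skip_dropout_prob)
        return torch.where(keep_skip, skip_x, x)
    return skip_x

# e.g.:
skip_x = torch.randn(50, 4, 192)
x = torch.randn(50, 4, 192)
out = apply_skip_dropout(skip_x, x, skip_dropout_prob=0.5, training=True)
```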