diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 0d6497326..b786f5068 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -210,7 +210,7 @@ class Zipformer(EncoderInterface):
         # initializes self.skip_layers and self.skip_modules
         self._init_skip_modules()
 
-        self.downsample_output = AttentionDownsample(encoder_dim[-1],
+        self.downsample_output = SimpleDownsample(encoder_dim[-1],
                                                      encoder_dim[-1],
                                                      downsample=output_downsampling_factor,
                                                      dropout=dropout)
@@ -678,7 +678,7 @@ class DownsampledZipformerEncoder(nn.Module):
                  dropout: FloatLike):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim,
+        self.downsample = SimpleDownsample(input_dim, output_dim,
                                               downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(output_dim, downsample)
@@ -731,7 +731,7 @@ class DownsampledZipformerEncoder(nn.Module):
 
 
 
-class AttentionDownsample(torch.nn.Module):
+class SimpleDownsample(torch.nn.Module):
     """
     Does downsampling with attention, by weighted sum, and a projection..
     """
@@ -743,8 +743,8 @@ class AttentionDownsample(torch.nn.Module):
         """
         Require out_channels > in_channels.
         """
-        super(AttentionDownsample, self).__init__()
-        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+        super(SimpleDownsample, self).__init__()
+        self.bias = nn.Parameter(torch.zeros(downsample))
 
         self.name = None # will be set from training code
 
@@ -779,24 +779,10 @@ class AttentionDownsample(torch.nn.Module):
             assert src.shape[0] == d_seq_len * ds
 
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
 
-        # scores: (d_seq_len, downsample, batch_size, 1)
-        scores = (src * self.query).sum(dim=-1, keepdim=True)
-        scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)
-        scores = penalize_abs_values_gt(scores,
-                                        limit=20.0,
-                                        penalty=1.0e-04,
-                                        name=self.name)
-
-        dropout = float(self.dropout)
-        if dropout > 0.0:
-            # the 0:1, done on the axis of size 'downsample', selects just
-            # one dimension while keeping the dim. We'll then broadcast when
-            # we multiply.
-            dropout_mask = torch.rand_like(scores[:, 0:1]) > dropout
-            scores = scores * dropout_mask
-
-        weights = scores.softmax(dim=1)
+        weights = self.bias.softmax(dim=0)
+        # weights: (downsample, 1, 1)
+        weights = weights.unsqueeze(-1).unsqueeze(-1)
 
         # ans1 is the first `in_channels` channels of the output
         ans = (src * weights).sum(dim=1)
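
For reference, the net effect of the new forward path is a fixed, learned convex combination of each group of `downsample` frames: the weights come from a softmax over `self.bias` and do not depend on the content of `src`, which is why the `penalize_abs_values_gt` clamp and the score-dropout branch above could be removed. What follows is a minimal sketch of just that computation, assuming padding and any output projection are handled elsewhere; the helper name `simple_downsample` is made up for illustration and is not part of the patch.

import torch

def simple_downsample(src: torch.Tensor, bias: torch.Tensor, ds: int) -> torch.Tensor:
    # src: (seq_len, batch_size, in_channels); seq_len is assumed to be an
    # exact multiple of ds (the real forward() right-pads src first).
    seq_len, batch_size, in_channels = src.shape
    d_seq_len = seq_len // ds
    src = src.reshape(d_seq_len, ds, batch_size, in_channels)

    # One learned scalar per position within each group of `ds` frames,
    # normalized with softmax and shared across time, batch and channels.
    weights = bias.softmax(dim=0)                  # (ds,)
    weights = weights.unsqueeze(-1).unsqueeze(-1)  # (ds, 1, 1)

    # Weighted sum over the `ds` axis -> (d_seq_len, batch_size, in_channels)
    return (src * weights).sum(dim=1)

# Example: with bias initialized to zeros (as in the new __init__), the
# softmax weights are uniform and the module starts out as a plain average.
bias = torch.zeros(2)
out = simple_downsample(torch.randn(6, 3, 8), bias, ds=2)
print(out.shape)  # torch.Size([3, 3, 8])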