From 86bb0623e9428822e3c79bb1bebc250581f75eb5 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 17 Dec 2022 13:45:30 +0800
Subject: [PATCH] Remove query from AttentionDownsample, rename to SimpleDownsample

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index b000e9062..1475bda5c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -209,7 +209,7 @@ class Zipformer(EncoderInterface):
 
         # initializes self.skip_layers and self.skip_modules
         self._init_skip_modules()
 
-        self.downsample_output = AttentionDownsample(encoder_dim[-1],
+        self.downsample_output = SimpleDownsample(encoder_dim[-1],
                                                      encoder_dim[-1],
                                                      downsample=output_downsampling_factor,
                                                      dropout=dropout)
@@ -677,7 +677,7 @@ class DownsampledZipformerEncoder(nn.Module):
                  dropout: FloatLike):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim,
+        self.downsample = SimpleDownsample(input_dim, output_dim,
                                               downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(output_dim, downsample)
@@ -741,7 +741,7 @@ class DownsamplingZipformerEncoder(nn.Module):
                  downsample: int):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
+        self.downsample = SimpleDownsample(input_dim, output_dim, downsample)
         self.encoder = encoder
 
 
@@ -785,7 +785,7 @@ class DownsamplingZipformerEncoder(nn.Module):
         return src
 
 
-class AttentionDownsample(torch.nn.Module):
+class SimpleDownsample(torch.nn.Module):
     """
     Does downsampling with attention, by weighted sum, and a projection..
     """
@@ -797,7 +797,7 @@
         """
         Require out_channels > in_channels.
         """
-        super(AttentionDownsample, self).__init__()
+        super(SimpleDownsample, self).__init__()
         self.bias = nn.Parameter(torch.zeros(downsample))
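
For context, here is a minimal sketch of what SimpleDownsample plausibly computes after this change. With the query removed, the per-frame combination weights are no longer data-dependent attention scores but simply a softmax over the learned `bias` parameter visible in the last hunk. Only the constructor arguments and the `self.bias` line come from the patch; the `forward` logic, the `out_proj` projection, and the use of plain `nn.Dropout` are assumptions for illustration, not the repository's exact code.

    import torch
    import torch.nn as nn


    class SimpleDownsample(nn.Module):
        """Downsample by a weighted sum over groups of frames, then project."""

        def __init__(self, in_channels: int, out_channels: int,
                     downsample: int, dropout: float):
            super().__init__()
            # One learned logit per position within a group of `downsample`
            # frames; this line appears verbatim in the patch.
            self.bias = nn.Parameter(torch.zeros(downsample))
            self.downsample = downsample
            self.dropout = nn.Dropout(dropout)  # assumption: plain dropout
            self.out_proj = nn.Linear(in_channels, out_channels)  # assumption

        def forward(self, src: torch.Tensor) -> torch.Tensor:
            # src: (seq_len, batch_size, in_channels)
            seq_len, batch_size, in_channels = src.shape
            ds = self.downsample
            d_seq_len = (seq_len + ds - 1) // ds

            # Right-pad by repeating the last frame so seq_len divides evenly.
            pad = d_seq_len * ds - seq_len
            if pad > 0:
                src = torch.cat(
                    (src, src[-1:].expand(pad, batch_size, in_channels)), dim=0)

            # Group consecutive frames: (d_seq_len, ds, batch, channels).
            src = src.reshape(d_seq_len, ds, batch_size, in_channels)

            # Data-independent softmax weights replace the old query-based
            # attention weights; applying dropout to them is an assumption.
            weights = self.dropout(self.bias.softmax(dim=0))
            ans = (src * weights.unsqueeze(-1).unsqueeze(-1)).sum(dim=1)

            # ans: (d_seq_len, batch_size, in_channels); project to out_channels.
            return self.out_proj(ans)

Since the weights depend only on position within each group of `downsample` frames, this amounts to a learned fixed pooling kernel, which is cheaper than the query-based attention pooling it replaces and is what motivates the rename from AttentionDownsample to SimpleDownsample.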