Remove query from AttentionDownsample, rename to SimpleDownsample

Daniel Povey 2022-12-17 13:45:30 +08:00
parent ed7e01448c
commit 86bb0623e9


@@ -209,7 +209,7 @@ class Zipformer(EncoderInterface):
         # initializes self.skip_layers and self.skip_modules
         self._init_skip_modules()
-        self.downsample_output = AttentionDownsample(encoder_dim[-1],
+        self.downsample_output = SimpleDownsample(encoder_dim[-1],
                                                      encoder_dim[-1],
                                                      downsample=output_downsampling_factor,
                                                      dropout=dropout)
@@ -677,7 +677,7 @@ class DownsampledZipformerEncoder(nn.Module):
                  dropout: FloatLike):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim,
+        self.downsample = SimpleDownsample(input_dim, output_dim,
                                               downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(output_dim, downsample)
@@ -741,7 +741,7 @@ class DownsamplingZipformerEncoder(nn.Module):
                  downsample: int):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
+        self.downsample = SimpleDownsample(input_dim, output_dim, downsample)
         self.encoder = encoder
@@ -785,7 +785,7 @@ class DownsamplingZipformerEncoder(nn.Module):
         return src
-class AttentionDownsample(torch.nn.Module):
+class SimpleDownsample(torch.nn.Module):
     """
     Does downsampling with attention, by weighted sum, and a projection..
     """
@@ -797,7 +797,7 @@ class AttentionDownsample(torch.nn.Module):
         """
         Require out_channels > in_channels.
         """
-        super(AttentionDownsample, self).__init__()
+        super(SimpleDownsample, self).__init__()
         self.bias = nn.Parameter(torch.zeros(downsample))
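
For reference, below is a minimal, self-contained sketch of the behaviour the new name implies, based only on the hunks above: each output frame is a weighted sum of `downsample` consecutive input frames, with the weights given by a softmax over the learned `bias` vector (no per-frame attention query), followed by a projection to `out_channels`. The class name `SimpleDownsampleSketch`, the plain `nn.Linear` projection, the padding of the last group, the placement of dropout, and the (seq_len, batch, channels) layout are illustrative assumptions, not taken from the commit.

import torch
import torch.nn as nn


class SimpleDownsampleSketch(nn.Module):
    """Weighted-sum downsampling with a learned per-position weight vector."""

    def __init__(self, in_channels: int, out_channels: int,
                 downsample: int, dropout: float = 0.0):
        super().__init__()
        self.downsample = downsample
        # One learned logit per position within each group of `downsample` frames
        # (mirrors the `self.bias = nn.Parameter(torch.zeros(downsample))` line above).
        self.bias = nn.Parameter(torch.zeros(downsample))
        self.out_proj = nn.Linear(in_channels, out_channels)  # assumed plain Linear
        self.dropout = nn.Dropout(dropout)  # where dropout acts is an assumption

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # src: (seq_len, batch_size, in_channels)
        seq_len, batch_size, in_channels = src.shape
        ds = self.downsample
        d_seq_len = (seq_len + ds - 1) // ds
        pad = d_seq_len * ds - seq_len
        if pad > 0:
            # Repeat the last frame so seq_len becomes a multiple of `ds` (assumption).
            src = torch.cat((src, src[-1:].expand(pad, batch_size, in_channels)), dim=0)
        # Group consecutive frames: (d_seq_len, ds, batch_size, in_channels).
        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
        # Fixed, input-independent weights over each group -- no attention query.
        weights = self.dropout(self.bias.softmax(dim=0)).reshape(1, ds, 1, 1)
        out = (src * weights).sum(dim=1)      # (d_seq_len, batch_size, in_channels)
        return self.out_proj(out)             # (d_seq_len, batch_size, out_channels)


# Example: halve the frame rate of a (seq_len=100, batch=8, dim=384) sequence.
m = SimpleDownsampleSketch(384, 384, downsample=2, dropout=0.1)
y = m(torch.randn(100, 8, 384))
print(y.shape)  # torch.Size([50, 8, 384])

Dropping the query makes the downsampling weights input-independent, so the renamed module keeps only the small per-position `bias` parameter plus the output projection.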