mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove query from AttentionDownsample, rename to SimpleDownsample
This commit is contained in:
parent
ed7e01448c
commit
86bb0623e9
@ -209,7 +209,7 @@ class Zipformer(EncoderInterface):
|
|||||||
# initializes self.skip_layers and self.skip_modules
|
# initializes self.skip_layers and self.skip_modules
|
||||||
self._init_skip_modules()
|
self._init_skip_modules()
|
||||||
|
|
||||||
self.downsample_output = AttentionDownsample(encoder_dim[-1],
|
self.downsample_output = SimpleDownsample(encoder_dim[-1],
|
||||||
encoder_dim[-1],
|
encoder_dim[-1],
|
||||||
downsample=output_downsampling_factor,
|
downsample=output_downsampling_factor,
|
||||||
dropout=dropout)
|
dropout=dropout)
|
||||||
@ -677,7 +677,7 @@ class DownsampledZipformerEncoder(nn.Module):
|
|||||||
dropout: FloatLike):
|
dropout: FloatLike):
|
||||||
super(DownsampledZipformerEncoder, self).__init__()
|
super(DownsampledZipformerEncoder, self).__init__()
|
||||||
self.downsample_factor = downsample
|
self.downsample_factor = downsample
|
||||||
self.downsample = AttentionDownsample(input_dim, output_dim,
|
self.downsample = SimpleDownsample(input_dim, output_dim,
|
||||||
downsample, dropout)
|
downsample, dropout)
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.upsample = SimpleUpsample(output_dim, downsample)
|
self.upsample = SimpleUpsample(output_dim, downsample)
|
||||||
@ -741,7 +741,7 @@ class DownsamplingZipformerEncoder(nn.Module):
|
|||||||
downsample: int):
|
downsample: int):
|
||||||
super(DownsampledZipformerEncoder, self).__init__()
|
super(DownsampledZipformerEncoder, self).__init__()
|
||||||
self.downsample_factor = downsample
|
self.downsample_factor = downsample
|
||||||
self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
|
self.downsample = SimpleDownsample(input_dim, output_dim, downsample)
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
|
|
||||||
|
|
||||||
@ -785,7 +785,7 @@ class DownsamplingZipformerEncoder(nn.Module):
|
|||||||
return src
|
return src
|
||||||
|
|
||||||
|
|
||||||
class AttentionDownsample(torch.nn.Module):
|
class SimpleDownsample(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Does downsampling with attention, by weighted sum, and a projection..
|
Does downsampling with attention, by weighted sum, and a projection..
|
||||||
"""
|
"""
|
||||||
@ -797,7 +797,7 @@ class AttentionDownsample(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Require out_channels > in_channels.
|
Require out_channels > in_channels.
|
||||||
"""
|
"""
|
||||||
super(AttentionDownsample, self).__init__()
|
super(SimpleDownsample, self).__init__()
|
||||||
|
|
||||||
self.bias = nn.Parameter(torch.zeros(downsample))
|
self.bias = nn.Parameter(torch.zeros(downsample))
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user