diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index 9f340f8e7..ee5358324 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -163,14 +163,14 @@ class Subformer(EncoderInterface):
             )
         encoder = encoders[mid]
 
-        for i in range(mid-1, -1, -1):
+        for i in range(1, mid+1):
             this_list = [ encoders[mid-i], encoder, encoders[mid+i] ]
             encoder = DownsampledSubformerEncoder(
                 this_list,
                 input_num_channels=encoder_dim[max(0, mid-i-1)],
-                downsample=2,
+                downsample=2 if i != mid else 1
             )
 
         self.encoder = encoder
@@ -937,8 +937,8 @@ class LearnedDownsamplingModule(nn.Module):
 
 class DownsampledSubformerEncoder(nn.Module):
     """
-    DownsampledSubformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
-    after convolutional downsampling, and then upsampled again at the output, and combined
+    DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
+    frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
     with the origin input, so that the output has the same shape as the input.
     """
     def __init__(self,
@@ -946,16 +946,18 @@ class DownsampledSubformerEncoder(nn.Module):
                  input_num_channels: int,
                  downsample: int):
         super(DownsampledSubformerEncoder, self).__init__()
-        self.downsample_factor = downsample
-        self.downsampler = LearnedDownsamplingModule(input_num_channels,
-                                                     downsample)
+
+        if downsample != 1:
+            self.downsampler = LearnedDownsamplingModule(input_num_channels,
+                                                         downsample)
+
         self.encoders = nn.ModuleList(encoders)
 
-        self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
+        self.out_combiner = BypassModule(self.embed_dim(),
                                          straight_through_rate=0.0)
 
-    def embed_dim(self):  # return output embed_dim.
-        return self.encoders[-1].embed_dim()
+    def embed_dim(self):  # return output embed_dim which is max dim.
+        return max(e.embed_dim() for e in self.encoders)
 
     def forward(self,
                 src: Tensor,
@@ -983,13 +985,15 @@ class DownsampledSubformerEncoder(nn.Module):
 
         Returns: a Tensor with the same shape as src.
         """
         src_orig = src
-        indexes, weights, src = self.downsampler(src)
-        pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
+        if hasattr(self, 'downsampler'):
+            indexes, weights, src = self.downsampler(src)
 
-        attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
-                                                              indexes,
-                                                              weights)
+            pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
+
+            attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
+                                                                  indexes,
+                                                                  weights)
 
         outputs = [ src ]
         for encoder in self.encoders:
@@ -1019,7 +1023,8 @@ class DownsampledSubformerEncoder(nn.Module):
         src = get_full_dim_output()
         src_orig = convert_num_channels(src_orig, src.shape[-1])
 
-        src = self.downsampler.upsample(src_orig, src, indexes)
+        if hasattr(self, 'downsampler'):
+            src = self.downsampler.upsample(src_orig, src, indexes)
 
         return self.out_combiner(src_orig, src)
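
For readers skimming the patch, here is a minimal standalone sketch of the pattern it introduces (the class and helper names below are illustrative stand-ins, not the actual Subformer or LearnedDownsamplingModule code): the downsampler is only constructed when downsample != 1, every use of it in forward() is guarded with hasattr(), and the output keeps the input's shape whether or not downsampling happened.

    # Illustrative sketch only: simplified stand-ins showing the "optional
    # downsampler" pattern from the patch, not the real Subformer modules.
    import torch
    import torch.nn as nn


    class NaiveDownsampler(nn.Module):
        """Keeps every `factor`-th frame (a stand-in for a learned downsampler)."""
        def __init__(self, factor: int):
            super().__init__()
            self.factor = factor

        def forward(self, src: torch.Tensor) -> torch.Tensor:
            # src: (seq_len, batch, channels) -> keep every factor-th frame.
            return src[::self.factor]

        def upsample(self, src_orig: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
            # Repeat each kept frame `factor` times, then trim to the original length.
            return src.repeat_interleave(self.factor, dim=0)[: src_orig.shape[0]]


    class OptionalDownsampleEncoder(nn.Module):
        """Runs `encoder` at a reduced frame rate only when downsample != 1."""
        def __init__(self, encoder: nn.Module, downsample: int):
            super().__init__()
            # As in the patch: only create the downsampler when it is needed,
            # and test for its presence with hasattr() in forward().
            if downsample != 1:
                self.downsampler = NaiveDownsampler(downsample)
            self.encoder = encoder

        def forward(self, src: torch.Tensor) -> torch.Tensor:
            src_orig = src
            if hasattr(self, 'downsampler'):
                src = self.downsampler(src)
            src = self.encoder(src)
            if hasattr(self, 'downsampler'):
                src = self.downsampler.upsample(src_orig, src)
            # Output has the same shape as the input, as the docstring promises.
            return src_orig + src


    if __name__ == "__main__":
        enc = nn.Linear(64, 64)
        x = torch.randn(17, 2, 64)  # (seq_len, batch, channels)
        print(OptionalDownsampleEncoder(enc, downsample=2)(x).shape)  # torch.Size([17, 2, 64])
        print(OptionalDownsampleEncoder(enc, downsample=1)(x).shape)  # torch.Size([17, 2, 64])

Leaving the attribute entirely absent (rather than setting self.downsampler = None) is what makes the hasattr() guards in the patched forward() work, which is why the construction above mirrors that choice.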