diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py index 48e632457..9f340f8e7 100644 --- a/egs/libriheavy/LM/zipformer1/subformer.py +++ b/egs/libriheavy/LM/zipformer1/subformer.py @@ -169,7 +169,7 @@ class Subformer(EncoderInterface): encoders[mid+i] ] encoder = DownsampledSubformerEncoder( this_list, - input_num_channels=encoder_dim[max(0, mid-2)], + input_num_channels=encoder_dim[max(0, mid-i-1)], downsample=2, ) @@ -954,6 +954,9 @@ class DownsampledSubformerEncoder(nn.Module): self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders), straight_through_rate=0.0) + def embed_dim(self): # return output embed_dim. + return self.encoders[-1].embed_dim() + def forward(self, src: Tensor, pos_emb: Tensor,