This commit is contained in:
Daniel Povey 2023-05-15 22:07:27 +08:00
parent 0a76215fd7
commit cc81ec4f8a

View File

@ -169,7 +169,7 @@ class Subformer(EncoderInterface):
encoders[mid+i] ]
encoder = DownsampledSubformerEncoder(
this_list,
input_num_channels=encoder_dim[max(0, mid-2)],
input_num_channels=encoder_dim[max(0, mid-i-1)],
downsample=2,
)
@ -954,6 +954,9 @@ class DownsampledSubformerEncoder(nn.Module):
self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
straight_through_rate=0.0)
def embed_dim(self): # return output embed_dim.
return self.encoders[-1].embed_dim()
def forward(self,
src: Tensor,
pos_emb: Tensor,