mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
bug fix
This commit is contained in:
parent
0a76215fd7
commit
cc81ec4f8a
@ -169,7 +169,7 @@ class Subformer(EncoderInterface):
|
|||||||
encoders[mid+i] ]
|
encoders[mid+i] ]
|
||||||
encoder = DownsampledSubformerEncoder(
|
encoder = DownsampledSubformerEncoder(
|
||||||
this_list,
|
this_list,
|
||||||
input_num_channels=encoder_dim[max(0, mid-2)],
|
input_num_channels=encoder_dim[max(0, mid-i-1)],
|
||||||
downsample=2,
|
downsample=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -954,6 +954,9 @@ class DownsampledSubformerEncoder(nn.Module):
|
|||||||
self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
|
self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
|
||||||
straight_through_rate=0.0)
|
straight_through_rate=0.0)
|
||||||
|
|
||||||
|
def embed_dim(self): # return output embed_dim.
|
||||||
|
return self.encoders[-1].embed_dim()
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
src: Tensor,
|
src: Tensor,
|
||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user