mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix bugs
This commit is contained in:
parent
cc81ec4f8a
commit
8001a46758
@ -163,14 +163,14 @@ class Subformer(EncoderInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
encoder = encoders[mid]
|
encoder = encoders[mid]
|
||||||
for i in range(mid-1, -1, -1):
|
for i in range(1, mid+1):
|
||||||
this_list = [ encoders[mid-i],
|
this_list = [ encoders[mid-i],
|
||||||
encoder,
|
encoder,
|
||||||
encoders[mid+i] ]
|
encoders[mid+i] ]
|
||||||
encoder = DownsampledSubformerEncoder(
|
encoder = DownsampledSubformerEncoder(
|
||||||
this_list,
|
this_list,
|
||||||
input_num_channels=encoder_dim[max(0, mid-i-1)],
|
input_num_channels=encoder_dim[max(0, mid-i-1)],
|
||||||
downsample=2,
|
downsample=2 if i != mid else 1
|
||||||
)
|
)
|
||||||
|
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
@ -937,8 +937,8 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
|
|
||||||
class DownsampledSubformerEncoder(nn.Module):
|
class DownsampledSubformerEncoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
DownsampledSubformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
|
DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
|
||||||
after convolutional downsampling, and then upsampled again at the output, and combined
|
frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
|
||||||
with the origin input, so that the output has the same shape as the input.
|
with the origin input, so that the output has the same shape as the input.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -946,16 +946,18 @@ class DownsampledSubformerEncoder(nn.Module):
|
|||||||
input_num_channels: int,
|
input_num_channels: int,
|
||||||
downsample: int):
|
downsample: int):
|
||||||
super(DownsampledSubformerEncoder, self).__init__()
|
super(DownsampledSubformerEncoder, self).__init__()
|
||||||
self.downsample_factor = downsample
|
|
||||||
|
if downsample != 1:
|
||||||
self.downsampler = LearnedDownsamplingModule(input_num_channels,
|
self.downsampler = LearnedDownsamplingModule(input_num_channels,
|
||||||
downsample)
|
downsample)
|
||||||
|
|
||||||
self.encoders = nn.ModuleList(encoders)
|
self.encoders = nn.ModuleList(encoders)
|
||||||
|
|
||||||
self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
|
self.out_combiner = BypassModule(self.embed_dim(),
|
||||||
straight_through_rate=0.0)
|
straight_through_rate=0.0)
|
||||||
|
|
||||||
def embed_dim(self): # return output embed_dim.
|
def embed_dim(self): # return output embed_dim which is max dim.
|
||||||
return self.encoders[-1].embed_dim()
|
return max(e.embed_dim() for e in self.encoders)
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
src: Tensor,
|
src: Tensor,
|
||||||
@ -983,6 +985,8 @@ class DownsampledSubformerEncoder(nn.Module):
|
|||||||
Returns: a Tensor with the same shape as src.
|
Returns: a Tensor with the same shape as src.
|
||||||
"""
|
"""
|
||||||
src_orig = src
|
src_orig = src
|
||||||
|
|
||||||
|
if hasattr(self, 'downsampler'):
|
||||||
indexes, weights, src = self.downsampler(src)
|
indexes, weights, src = self.downsampler(src)
|
||||||
|
|
||||||
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
|
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
|
||||||
@ -1019,6 +1023,7 @@ class DownsampledSubformerEncoder(nn.Module):
|
|||||||
src = get_full_dim_output()
|
src = get_full_dim_output()
|
||||||
src_orig = convert_num_channels(src_orig, src.shape[-1])
|
src_orig = convert_num_channels(src_orig, src.shape[-1])
|
||||||
|
|
||||||
|
if hasattr(self, 'downsampler'):
|
||||||
src = self.downsampler.upsample(src_orig, src, indexes)
|
src = self.downsampler.upsample(src_orig, src, indexes)
|
||||||
|
|
||||||
return self.out_combiner(src_orig, src)
|
return self.out_combiner(src_orig, src)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user