mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Fix bugs
This commit is contained in:
parent
cc81ec4f8a
commit
8001a46758
@ -163,14 +163,14 @@ class Subformer(EncoderInterface):
|
||||
)
|
||||
|
||||
encoder = encoders[mid]
|
||||
for i in range(mid-1, -1, -1):
|
||||
for i in range(1, mid+1):
|
||||
this_list = [ encoders[mid-i],
|
||||
encoder,
|
||||
encoders[mid+i] ]
|
||||
encoder = DownsampledSubformerEncoder(
|
||||
this_list,
|
||||
input_num_channels=encoder_dim[max(0, mid-i-1)],
|
||||
downsample=2,
|
||||
downsample=2 if i != mid else 1
|
||||
)
|
||||
|
||||
self.encoder = encoder
|
||||
@ -937,8 +937,8 @@ class LearnedDownsamplingModule(nn.Module):
|
||||
|
||||
class DownsampledSubformerEncoder(nn.Module):
|
||||
"""
|
||||
DownsampledSubformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
|
||||
after convolutional downsampling, and then upsampled again at the output, and combined
|
||||
DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
|
||||
frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
|
||||
with the origin input, so that the output has the same shape as the input.
|
||||
"""
|
||||
def __init__(self,
|
||||
@ -946,16 +946,18 @@ class DownsampledSubformerEncoder(nn.Module):
|
||||
input_num_channels: int,
|
||||
downsample: int):
|
||||
super(DownsampledSubformerEncoder, self).__init__()
|
||||
self.downsample_factor = downsample
|
||||
self.downsampler = LearnedDownsamplingModule(input_num_channels,
|
||||
downsample)
|
||||
|
||||
if downsample != 1:
|
||||
self.downsampler = LearnedDownsamplingModule(input_num_channels,
|
||||
downsample)
|
||||
|
||||
self.encoders = nn.ModuleList(encoders)
|
||||
|
||||
self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
|
||||
self.out_combiner = BypassModule(self.embed_dim(),
|
||||
straight_through_rate=0.0)
|
||||
|
||||
def embed_dim(self): # return output embed_dim.
|
||||
return self.encoders[-1].embed_dim()
|
||||
def embed_dim(self): # return output embed_dim which is max dim.
|
||||
return max(e.embed_dim() for e in self.encoders)
|
||||
|
||||
def forward(self,
|
||||
src: Tensor,
|
||||
@ -983,13 +985,15 @@ class DownsampledSubformerEncoder(nn.Module):
|
||||
Returns: a Tensor with the same shape as src.
|
||||
"""
|
||||
src_orig = src
|
||||
indexes, weights, src = self.downsampler(src)
|
||||
|
||||
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
|
||||
if hasattr(self, 'downsampler'):
|
||||
indexes, weights, src = self.downsampler(src)
|
||||
|
||||
attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
|
||||
indexes,
|
||||
weights)
|
||||
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
|
||||
|
||||
attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
|
||||
indexes,
|
||||
weights)
|
||||
outputs = [ src ]
|
||||
|
||||
for encoder in self.encoders:
|
||||
@ -1019,7 +1023,8 @@ class DownsampledSubformerEncoder(nn.Module):
|
||||
src = get_full_dim_output()
|
||||
src_orig = convert_num_channels(src_orig, src.shape[-1])
|
||||
|
||||
src = self.downsampler.upsample(src_orig, src, indexes)
|
||||
if hasattr(self, 'downsampler'):
|
||||
src = self.downsampler.upsample(src_orig, src, indexes)
|
||||
|
||||
return self.out_combiner(src_orig, src)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user