This commit is contained in:
Daniel Povey 2023-05-15 22:49:43 +08:00
parent cc81ec4f8a
commit 8001a46758

View File

@ -163,14 +163,14 @@ class Subformer(EncoderInterface):
) )
encoder = encoders[mid] encoder = encoders[mid]
for i in range(mid-1, -1, -1): for i in range(1, mid+1):
this_list = [ encoders[mid-i], this_list = [ encoders[mid-i],
encoder, encoder,
encoders[mid+i] ] encoders[mid+i] ]
encoder = DownsampledSubformerEncoder( encoder = DownsampledSubformerEncoder(
this_list, this_list,
input_num_channels=encoder_dim[max(0, mid-i-1)], input_num_channels=encoder_dim[max(0, mid-i-1)],
downsample=2, downsample=2 if i != mid else 1
) )
self.encoder = encoder self.encoder = encoder
@ -937,8 +937,8 @@ class LearnedDownsamplingModule(nn.Module):
class DownsampledSubformerEncoder(nn.Module): class DownsampledSubformerEncoder(nn.Module):
""" """
DownsampledSubformerEncoder is a zipformer encoder evaluated at a reduced frame rate, DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
after convolutional downsampling, and then upsampled again at the output, and combined frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
with the original input, so that the output has the same shape as the input. with the original input, so that the output has the same shape as the input.
""" """
def __init__(self, def __init__(self,
@ -946,16 +946,18 @@ class DownsampledSubformerEncoder(nn.Module):
input_num_channels: int, input_num_channels: int,
downsample: int): downsample: int):
super(DownsampledSubformerEncoder, self).__init__() super(DownsampledSubformerEncoder, self).__init__()
self.downsample_factor = downsample
self.downsampler = LearnedDownsamplingModule(input_num_channels, if downsample != 1:
downsample) self.downsampler = LearnedDownsamplingModule(input_num_channels,
downsample)
self.encoders = nn.ModuleList(encoders) self.encoders = nn.ModuleList(encoders)
self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders), self.out_combiner = BypassModule(self.embed_dim(),
straight_through_rate=0.0) straight_through_rate=0.0)
def embed_dim(self): # return output embed_dim. def embed_dim(self): # return output embed_dim which is max dim.
return self.encoders[-1].embed_dim() return max(e.embed_dim() for e in self.encoders)
def forward(self, def forward(self,
src: Tensor, src: Tensor,
@ -983,13 +985,15 @@ class DownsampledSubformerEncoder(nn.Module):
Returns: a Tensor with the same shape as src. Returns: a Tensor with the same shape as src.
""" """
src_orig = src src_orig = src
indexes, weights, src = self.downsampler(src)
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes) if hasattr(self, 'downsampler'):
indexes, weights, src = self.downsampler(src)
attn_offset = self.downsampler.downsample_attn_offset(attn_offset, pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
indexes,
weights) attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
indexes,
weights)
outputs = [ src ] outputs = [ src ]
for encoder in self.encoders: for encoder in self.encoders:
@ -1019,7 +1023,8 @@ class DownsampledSubformerEncoder(nn.Module):
src = get_full_dim_output() src = get_full_dim_output()
src_orig = convert_num_channels(src_orig, src.shape[-1]) src_orig = convert_num_channels(src_orig, src.shape[-1])
src = self.downsampler.upsample(src_orig, src, indexes) if hasattr(self, 'downsampler'):
src = self.downsampler.upsample(src_orig, src, indexes)
return self.out_combiner(src_orig, src) return self.out_combiner(src_orig, src)