This commit is contained in:
Daniel Povey 2023-05-15 22:49:43 +08:00
parent cc81ec4f8a
commit 8001a46758

View File

@ -163,14 +163,14 @@ class Subformer(EncoderInterface):
)
encoder = encoders[mid]
for i in range(mid-1, -1, -1):
for i in range(1, mid+1):
this_list = [ encoders[mid-i],
encoder,
encoders[mid+i] ]
encoder = DownsampledSubformerEncoder(
this_list,
input_num_channels=encoder_dim[max(0, mid-i-1)],
downsample=2,
downsample=2 if i != mid else 1
)
self.encoder = encoder
@ -937,8 +937,8 @@ class LearnedDownsamplingModule(nn.Module):
class DownsampledSubformerEncoder(nn.Module):
"""
DownsampledSubformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
after convolutional downsampling, and then upsampled again at the output, and combined
DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
with the origin input, so that the output has the same shape as the input.
"""
def __init__(self,
@ -946,16 +946,18 @@ class DownsampledSubformerEncoder(nn.Module):
input_num_channels: int,
downsample: int):
super(DownsampledSubformerEncoder, self).__init__()
self.downsample_factor = downsample
self.downsampler = LearnedDownsamplingModule(input_num_channels,
downsample)
if downsample != 1:
self.downsampler = LearnedDownsamplingModule(input_num_channels,
downsample)
self.encoders = nn.ModuleList(encoders)
self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
self.out_combiner = BypassModule(self.embed_dim(),
straight_through_rate=0.0)
def embed_dim(self): # return output embed_dim.
return self.encoders[-1].embed_dim()
def embed_dim(self): # return output embed_dim which is max dim.
return max(e.embed_dim() for e in self.encoders)
def forward(self,
src: Tensor,
@ -983,13 +985,15 @@ class DownsampledSubformerEncoder(nn.Module):
Returns: a Tensor with the same shape as src.
"""
src_orig = src
indexes, weights, src = self.downsampler(src)
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
if hasattr(self, 'downsampler'):
indexes, weights, src = self.downsampler(src)
attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
indexes,
weights)
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
indexes,
weights)
outputs = [ src ]
for encoder in self.encoders:
@ -1019,7 +1023,8 @@ class DownsampledSubformerEncoder(nn.Module):
src = get_full_dim_output()
src_orig = convert_num_channels(src_orig, src.shape[-1])
src = self.downsampler.upsample(src_orig, src, indexes)
if hasattr(self, 'downsampler'):
src = self.downsampler.upsample(src_orig, src, indexes)
return self.out_combiner(src_orig, src)