Add skip connections as in normal U-net

This commit is contained in:
Daniel Povey 2022-10-29 19:47:10 +08:00
parent ff03ec88a5
commit 05689f6354

View File

@ -131,11 +131,41 @@ class Zipformer(EncoderInterface):
encoders.append(encoder)
self.encoders = nn.ModuleList(encoders)
# initializes self.skip_layers and self.skip_modules
self._init_skip_modules()
self.downsample_output = AttentionDownsample(encoder_dims[-1],
encoder_dims[-1],
downsample=output_downsampling_factor)
def _init_skip_modules(self) -> None:
    """
    Initialize ``self.skip_layers`` and ``self.skip_modules``, which implement
    U-net-style skip connections between encoder stacks.

    If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input
    of the layer indexed 4 (in zero indexing), which has downsampling_factor=4, we
    combine the outputs of layers 2 and 3; and at the input of the layer indexed 5,
    which has downsampling_factor=2, we combine the outputs of layers 1 and 4.

    For each stack ``i``, ``skip_layers[i]`` is either ``None`` (no skip
    connection) or the index ``j`` of the earlier stack whose output is combined
    with the previous stack's output; ``skip_modules[i]`` is the corresponding
    combiner module (``nn.Identity()`` when there is no skip connection, so the
    two lists stay index-aligned with the encoder stacks).
    """
    skip_layers = []
    skip_modules = []
    z = self.zipformer_downsampling_factors
    for i in range(len(z)):
        # No skip connection for the first two stacks, or when the previous
        # stack is not more heavily downsampled than this one (i.e. we only
        # add a skip when resolution increases again, on the "up" side of
        # the U-net).
        if i <= 1 or z[i-1] <= z[i]:
            skip_layers.append(None)
            skip_modules.append(nn.Identity())
        else:
            # Search backwards for the most recent earlier stack whose
            # downsampling factor is <= this stack's, i.e. a stack at the
            # same or finer resolution, to use as the skip source.
            for j in range(i-2, -1, -1):
                if z[j] <= z[i]:
                    # TEMP logging statement; NOTE(review): likely intended to
                    # be removed once the skip-connection wiring is verified.
                    logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
                                 f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.")
                    skip_layers.append(j)
                    # SimpleCombiner merges the skip-source output (dim of
                    # layer j) with the previous layer's output (dim of i-1).
                    skip_modules.append(SimpleCombiner(self.encoder_dims[j],
                                                       self.encoder_dims[i-1]))
                    break
    self.skip_layers = skip_layers
    self.skip_modules = nn.ModuleList(skip_modules)
def get_feature_masks(
self,
x: torch.Tensor) -> List[Union[float, Tensor]]:
@ -220,20 +250,23 @@ class Zipformer(EncoderInterface):
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
outputs = []
feature_masks = self.get_feature_masks(x)
for i, module in enumerate(self.encoders):
ds = self.zipformer_downsampling_factors[i]
if self.skip_layers[i] is not None:
x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
x = module(x,
feature_mask=feature_masks[i],
src_key_padding_mask=None if mask is None else mask[...,::ds])
outputs.append(x)
x = self.downsample_output(x)
# class Downsample has this rounding behavior..
assert self.output_downsampling_factor == 2
lengths = (lengths + 1) // 2
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths