From ed1b4d5e5d1b74316b437e63f4607507a06511b1 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 28 Oct 2022 17:32:38 +0800
Subject: [PATCH] Refactor zipformer for more flexibility so we can change number of encoder layers.

---
 .../ASR/pruned_transducer_stateless7/train.py |  22 +-
 .../pruned_transducer_stateless7/zipformer.py | 246 +++++++++++-------
 2 files changed, 161 insertions(+), 107 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 523fa83bf..c03598895 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -127,19 +127,19 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--encoder-unmasked-dim",
-        type=int,
-        default=256,
-        help="Unmasked dimension in the encoder, relates to augmentation during training. "
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256",
+        help="Unmasked dimensions in the encoders; this relates to augmentation during training. "
         "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
         " worse."
     )
 
     parser.add_argument(
-        "--zipformer-subsampling-factor",
-        type=int,
-        default=2,
-        help="Subsampling factor for 2nd stack of encoder layers.",
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2",
+        help="Downsampling factor for each stack of encoder layers.",
     )
 
     parser.add_argument(
@@ -437,10 +437,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Zipformer(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
-        zipformer_subsampling_factor=params.zipformer_subsampling_factor,
-        d_model=to_int_tuple(params.encoder_dims),
+        zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors),
+        encoder_dims=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
-        encoder_unmasked_dim=params.encoder_unmasked_dim,
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
         nhead=to_int_tuple(params.nhead),
         feedforward_dim=to_int_tuple(params.feedforward_dims),
         num_encoder_layers=to_int_tuple(params.num_encoder_layers),
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e921de326..70f1a71d0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -64,9 +64,10 @@ class Zipformer(EncoderInterface):
         num_features: int,
         subsampling_factor: int = 4,
         zipformer_subsampling_factor: int = 4,
-        d_model: Tuple[int] = (384, 384),
+        encoder_dims: Tuple[int] = (384, 384),
         attention_dim: Tuple[int] = (256, 256),
-        encoder_unmasked_dim: int = 256,
+        encoder_unmasked_dims: Tuple[int] = (256, 256),
+        zipformer_downsampling_factors: Tuple[int] = (1, 2),
         nhead: Tuple[int] = (8, 8),
         feedforward_dim: Tuple[int] = (1536, 2048),
         num_encoder_layers: Tuple[int] = (12, 12),
@@ -78,74 +79,68 @@
         self.num_features = num_features
         self.subsampling_factor = subsampling_factor
-        self.encoder_unmasked_dim = encoder_unmasked_dim
-        assert 0 < d_model[0] <= d_model[1]
-        self.d_model = d_model
-        self.zipformer_subsampling_factor = zipformer_subsampling_factor
+        self.encoder_unmasked_dims = encoder_unmasked_dims
+        assert 0 < encoder_dims[0] <= encoder_dims[1]
+        self.encoder_dims = encoder_dims
+        self.zipformer_downsampling_factors = zipformer_downsampling_factors
 
-        assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]
+        for u, d in zip(encoder_unmasked_dims, encoder_dims):
+            assert u <= d
 
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
         # self.encoder_embed converts the input of shape (N, T, num_features)
-        # to the shape (N, T//subsampling_factor, d_model).
+        # to the shape (N, T//subsampling_factor, encoder_dims).
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
-        #   (2) embedding: num_features -> d_model
+        #   (2) embedding: num_features -> encoder_dims
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model[0],
+        self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0],
                                                dropout=dropout)
 
-        encoder_layer1 = ZipformerEncoderLayer(
-            d_model[0],
-            attention_dim[0],
-            nhead[0],
-            feedforward_dim[0],
-            dropout,
-            cnn_module_kernel[0],
-        )
-        # for the first third of the warmup period, we let the Conv2dSubsampling
-        # layer learn something.  then start warmup up the first and then the second
-        # encoder.
-        self.encoder1 = ZipformerEncoder(
-            encoder_layer1,
-            num_encoder_layers[0],
-            dropout,
-            warmup_begin=warmup_batches / 3,
-            warmup_end=warmup_batches * 2 / 3,
-        )
-        encoder_layer2 = ZipformerEncoderLayer(
-            d_model[1],
-            attention_dim[1],
-            nhead[1],
-            feedforward_dim[1],
-            dropout,
-            cnn_module_kernel[1],
-        )
-        self.encoder2 = DownsampledZipformerEncoder(
-            ZipformerEncoder(
-                encoder_layer2,
-                num_encoder_layers[1],
-                dropout,
-                warmup_begin=warmup_batches * 2 / 3,
-                warmup_end=warmup_batches,
-            ),
-            input_dim=d_model[0],
-            output_dim=d_model[1],
-            downsample=zipformer_subsampling_factor,
-        )
-
-        self.out_combiner = SimpleCombiner(d_model[0],
-                                           d_model[1])
+        # each one will be ZipformerEncoder or DownsampledZipformerEncoder
+        encoders = []
+
+        num_encoders = len(encoder_dims)
+        for i in range(num_encoders):
+            encoder_layer = ZipformerEncoderLayer(
+                encoder_dims[i],
+                attention_dim[i],
+                nhead[i],
+                feedforward_dim[i],
+                dropout,
+                cnn_module_kernel[i],
+            )
+
+            # For the first segment of the warmup period, we let the Conv2dSubsampling
+            # layer learn something.  Then we warm up the encoder stacks one after another.
+            encoder = ZipformerEncoder(
+                encoder_layer,
+                num_encoder_layers[i],
+                dropout,
+                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+            )
+
+            if zipformer_downsampling_factors[i] != 1:
+                assert i > 0, "First zipformer stack cannot use downsampling"
+                encoder = DownsampledZipformerEncoder(
+                    encoder,
+                    input_dim=encoder_dims[i-1],
+                    output_dim=encoder_dims[i],
+                    downsample=zipformer_downsampling_factors[i],
+                )
+            encoders.append(encoder)
+        self.encoders = nn.ModuleList(encoders)
 
-    def get_feature_mask(
+    def get_feature_masks(
             self,
-            x: torch.Tensor) -> Tuple[Union[float, Tensor], Union[float, Tensor]]:
+            x: torch.Tensor) -> List[Union[float, Tensor]]:
         """
-        In eval mode, returns 1.0; in training mode, returns two randomized feature masks
-        for the 1st and second encoders (which may run at different frame rates).
+        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
+        randomized feature masks, one per encoder.
         On e.g. 15% of frames, these masks will zero out all encoder dims larger than
         some supplied number, e.g. >256, so in effect on those frames we are using
         a smaller encoder dim.
@@ -156,40 +152,46 @@ class Zipformer(EncoderInterface):
 
         Args:
            x: the embeddings (needed for the shape and dtype and device), of shape
-             (num_frames, batch_size, d_model0)
+             (num_frames, batch_size, encoder_dims0)
         """
+        num_encoders = len(self.encoder_dims)
         if not self.training:
-            return 1.0, 1.0
+            return [1.0] * num_encoders
+
+        (num_frames0, batch_size, _encoder_dims0) = x.shape
+        assert self.encoder_dims[0] == _encoder_dims0
+
+        max_downsampling_factor = max(self.zipformer_downsampling_factors)
+        num_frames_max = (num_frames0 + max_downsampling_factor - 1)
 
-        d_model0, d_model1 = self.d_model
-        (num_frames0, batch_size, _d_model0) = x.shape
-        assert d_model0 == _d_model0
-        ds = self.zipformer_subsampling_factor
-        num_frames1 = ((num_frames0 + ds - 1) // ds)
-        # on this proportion of the frames, drop out the extra features above
-        # self.encoder_unmasked_dim.
         feature_mask_dropout_prob = 0.15
-        # frame_mask1 shape: (num_frames1, batch_size, 1)
-        frame_mask1 = (torch.rand(num_frames1, batch_size, 1, device=x.device) >
-                       feature_mask_dropout_prob).to(x.dtype)
+        # frame_mask_max shape: (num_frames_max, batch_size, 1)
+        frame_mask_max = (torch.rand(num_frames_max, batch_size, 1,
+                                     device=x.device) >
+                          feature_mask_dropout_prob).to(x.dtype)
 
-        feature_mask1 = torch.ones(num_frames1, batch_size, self.d_model[1],
-                                   dtype=x.dtype, device=x.device)
-        feature_mask1[:, :, self.encoder_unmasked_dim:] *= frame_mask1
+        feature_masks = []
+        for i in range(num_encoders):
+            ds = self.zipformer_downsampling_factors[i]
+            upsample_factor = (max_downsampling_factor // ds)
 
-        # frame_mask0 shape: (num_frames0, batch_size, 1)
-        frame_mask0 = frame_mask1.unsqueeze(1).expand(num_frames1, ds, batch_size, 1).reshape(
-            num_frames1 * ds, batch_size, 1)[:num_frames0]
-
-        feature_mask0 = torch.ones(num_frames0, batch_size, self.d_model[0],
-                                   dtype=x.dtype, device=x.device)
-        feature_mask0[:, :, self.encoder_unmasked_dim:] *= frame_mask0
-
-        return feature_mask0, feature_mask1
+            frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor,
+                                                             batch_size, 1)
+                          .reshape(num_frames_max * upsample_factor, batch_size, 1))
+            num_frames = (num_frames0 + ds - 1) // ds
+            frame_mask = frame_mask[:num_frames]
+            feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i],
+                                      dtype=x.dtype, device=x.device)
+            u = self.encoder_unmasked_dims[i]
+            feature_mask[:, :, u:] *= frame_mask
+            feature_masks.append(feature_mask)
+
+        return feature_masks
 
     def forward(
@@ -204,7 +206,7 @@ class Zipformer(EncoderInterface):
           `x` before padding.
         Returns:
           Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, d_model)
+            - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1])
            - lengths, a tensor of shape (batch_size,) containing the number of frames in
              `embeddings` before padding.
""" @@ -219,18 +221,14 @@ class Zipformer(EncoderInterface): assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - feature_mask0, feature_mask1 = self.get_feature_mask(x) + feature_masks = self.get_feature_masks(x) - # x1: - x1 = self.encoder1( - x, feature_mask=feature_mask0, src_key_padding_mask=mask, - ) # (T, N, C) where C == d_model[0] + for i, module in enumerate(self.encoders): + ds = self.zipformer_downsampling_factors[i] + x = module(x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[...,::ds]) - x2 = self.encoder2( - x1, feature_mask=feature_mask1, src_key_padding_mask=mask, - ) # (T, N, C) where C == d_model[1] - - x = self.out_combiner(x1, x2) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -556,8 +554,8 @@ class ZipformerEncoder(nn.Module): class DownsampledZipformerEncoder(nn.Module): r""" DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output - so that the output has the same shape as the input. + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. """ def __init__(self, encoder: nn.Module, @@ -569,6 +567,67 @@ class DownsampledZipformerEncoder(nn.Module): self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) + self.out_combiner = SimpleCombiner(input_dim, + output_dim) + + + def forward(self, + src: Tensor, + feature_mask: Union[Tensor, float] = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. feature_mask is expected to be already downsampled by + self.downsample_factor. + mask: the mask for the src sequence (optional). CAUTION: we need to downsample + this, if we are to support it. Won't work correctly yet. + src_key_padding_mask: the mask for the src keys per batch (optional). Should + be downsampled already. + + Shape: + src: (S, N, E). + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if mask is not None: + mask = mask[::ds,::ds] + + src = self.encoder( + src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[:src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + +class DownsamplingZipformerEncoder(nn.Module): + r""" + DownsamplingZipformerEncoder is a zipformer encoder that downsamples its input + by a specified factor before feeding it to the zipformer layers. 
+    """
+    def __init__(self,
+                 encoder: nn.Module,
+                 input_dim: int,
+                 output_dim: int,
+                 downsample: int):
+        super(DownsamplingZipformerEncoder, self).__init__()
+        self.downsample_factor = downsample
+        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
+        self.encoder = encoder
 
     def forward(self,
@@ -608,10 +667,6 @@ class DownsampledZipformerEncoder(nn.Module):
         src = self.encoder(
             src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask,
         )
-        src = self.upsample(src)
-        # remove any extra frames that are not a multiple of downsample_factor
-        src = src[:src_orig.shape[0]]
-
         return src
 
 
@@ -1642,7 +1697,7 @@ def _test_zipformer_main():
 
     # Just make sure the forward pass runs.
     c = Zipformer(
-        num_features=feature_dim, d_model=(64,96), encoder_unmasked_dim=64, nhead=(4,4)
+        num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4)
    )
    batch_size = 5
    seq_len = 20
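--

For readers who want to see the masking scheme the refactor generalizes in isolation: the sketch below (not part of the commit) draws one coarse random frame mask and re-expands it for each encoder stack, so a stack running at downsampling factor ds reuses coarse entry t // (max_ds // ds) at its frame t, and all stacks drop the same stretches of time. The helper name make_feature_masks and the toy sizes are illustrative assumptions, not icefall API.

    # Standalone sketch of the shared per-stack feature masks (illustrative only).
    import torch

    def make_feature_masks(num_frames0, batch_size, encoder_dims,
                           unmasked_dims, downsampling_factors,
                           dropout_prob=0.15):
        max_ds = max(downsampling_factors)
        # One shared random mask, long enough that every stack's
        # expanded-and-truncated copy is covered.
        num_frames_max = num_frames0 + max_ds - 1
        frame_mask_max = (torch.rand(num_frames_max, batch_size, 1)
                          > dropout_prob).float()
        masks = []
        for dim, u, ds in zip(encoder_dims, unmasked_dims, downsampling_factors):
            upsample = max_ds // ds
            # Repeat each coarse frame `upsample` times to reach this stack's
            # frame rate, then truncate to this stack's actual frame count.
            frame_mask = (frame_mask_max.unsqueeze(1)
                          .expand(num_frames_max, upsample, batch_size, 1)
                          .reshape(num_frames_max * upsample, batch_size, 1))
            num_frames = (num_frames0 + ds - 1) // ds
            frame_mask = frame_mask[:num_frames]
            # Dims below `u` are never masked; dims >= u are zeroed on
            # ~dropout_prob of frames.
            feature_mask = torch.ones(num_frames, batch_size, dim)
            feature_mask[:, :, u:] *= frame_mask
            masks.append(feature_mask)
        return masks

    masks = make_feature_masks(num_frames0=20, batch_size=2,
                               encoder_dims=(8, 8), unmasked_dims=(6, 6),
                               downsampling_factors=(1, 2))
    print([tuple(m.shape) for m in masks])  # [(20, 2, 8), (10, 2, 8)]

With two stacks at factors (1, 2), the second stack's mask has half as many frames, and each of its frames shares its random draw with the two corresponding full-rate frames, mirroring how get_feature_masks() in the patch keeps masking consistent across frame rates.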