diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index ffaa6660f..35386c9df 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -41,7 +41,7 @@ class Conformer(EncoderInterface):
     Args:
         num_features (int): Number of input features
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): (attention_dim1, attention_dim2, output_dim)
+        d_model (Tuple[int]): embedding dimension of each encoder stack
         nhead (int): number of head
         dim_feedforward (int): feedforward dimention
         num_encoder_layers (int): number of encoder layers
@@ -56,7 +56,8 @@ class Conformer(EncoderInterface):
         num_features: int,
         subsampling_factor: int = 4,
         conformer_subsampling_factor: int = 4,
-        d_model: Tuple[int] = (256, 384, 512),
+        d_model: Tuple[int] = (384, 384),
+        encoder_unmasked_dim: int = 256,
         nhead: Tuple[int] = (8, 8),
         feedforward_dim: Tuple[int] = (1536, 2048),
         num_encoder_layers: Tuple[int] = (12, 12),
@@ -67,6 +68,13 @@ class Conformer(EncoderInterface):
         self.num_features = num_features
         self.subsampling_factor = subsampling_factor
+        self.encoder_unmasked_dim = encoder_unmasked_dim
+        assert 0 < d_model[0] <= d_model[1]
+        self.d_model = d_model
+        self.conformer_subsampling_factor = conformer_subsampling_factor
+
+        assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]
+
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
@@ -112,6 +120,64 @@ class Conformer(EncoderInterface):
 
         self.out_combiner = SimpleCombiner(d_model[0], d_model[1])
 
+    def get_feature_mask(
+            self,
+            x: torch.Tensor) -> Tuple[Union[float, Tensor], Union[float, Tensor]]:
+        """
+        In eval mode, returns (1.0, 1.0); in training mode, returns two randomized feature masks
+        for the first and second encoders (which may run at different frame rates).
+        On e.g. 15% of frames, these masks will zero out all encoder dims larger than
+        some supplied number, e.g. >256, so in effect on those frames we are using
+        a smaller encoder dim.
+
+        We generate the random masks at this level because we want the 2 masks to 'agree'
+        all the way up the encoder stack.  This means that the 1st mask will have its
+        mask values repeated self.conformer_subsampling_factor times.
+
+        Args:
+          x: the embeddings (needed for the shape, dtype and device), of shape
+            (num_frames, batch_size, d_model0)
+        """
+        if not self.training:
+            return 1.0, 1.0
+
+        d_model0, d_model1 = self.d_model
+        (num_frames0, batch_size, _d_model0) = x.shape
+        assert d_model0 == _d_model0
+        ds = self.conformer_subsampling_factor
+        num_frames1 = ((num_frames0 + ds - 1) // ds)
+
+        # on this proportion of the frames, drop out the extra features above
+        # self.encoder_unmasked_dim.
+        feature_mask_dropout_prob = 0.15
+
+        # we only apply the random frame masking on 90% of sequences; we leave the remaining 10%
+        # un-masked so that the model has seen un-masked data.
+        sequence_mask_dropout_prob = 0.9
+
+        # frame_mask is 0 with probability `feature_mask_dropout_prob`
+        # frame_mask1 shape: (num_frames1, batch_size, 1)
+        frame_mask1 = torch.logical_or(
+            torch.rand(num_frames1, batch_size, 1, device=x.device) > feature_mask_dropout_prob,
+            torch.rand(1, batch_size, 1, device=x.device) > sequence_mask_dropout_prob).to(x.dtype)
+
+        feature_mask1 = torch.ones(num_frames1, batch_size, self.d_model[1],
+                                   dtype=x.dtype, device=x.device)
+        feature_mask1[:, :, self.encoder_unmasked_dim:] *= frame_mask1
+
+
+        # frame_mask0 shape: (num_frames0, batch_size, 1)
+        frame_mask0 = frame_mask1.unsqueeze(1).expand(num_frames1, ds, batch_size, 1).reshape(
+            num_frames1 * ds, batch_size, 1)[:num_frames0]
+
+        print("frame_mask0 = ", frame_mask0.squeeze(-1))
+
+        feature_mask0 = torch.ones(num_frames0, batch_size, self.d_model[0],
+                                   dtype=x.dtype, device=x.device)
+        feature_mask0[:, :, self.encoder_unmasked_dim:] *= frame_mask0
+
+        return feature_mask0, feature_mask1
+
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor,
@@ -140,18 +206,19 @@ class Conformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
 
         mask = make_pad_mask(lengths)
 
+        feature_mask0, feature_mask1 = self.get_feature_mask(x)
+
         # x1:
         x1 = self.encoder1(
-            x, src_key_padding_mask=mask,
+            x, feature_mask=feature_mask0, src_key_padding_mask=mask,
         )  # (T, N, C) where C == d_model[0]
 
         x2 = self.encoder2(
-            x1, src_key_padding_mask=mask,
+            x1, feature_mask=feature_mask1, src_key_padding_mask=mask,
         )  # (T, N, C) where C == d_model[1]
 
         x = self.out_combiner(x1, x2)
-
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
         return x, lengths
@@ -319,6 +386,7 @@ class ConformerEncoder(nn.Module):
     def forward(
         self,
         src: Tensor,
+        feature_mask: Union[Tensor, float] = 1.0,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
@@ -326,6 +394,8 @@ class ConformerEncoder(nn.Module):
 
         Args:
             src: the sequence to the encoder (required).
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+               by at every layer.
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
 
@@ -344,33 +414,8 @@ class ConformerEncoder(nn.Module):
 
         outputs = []
         attn_scores = None
-
-        # deal with feature masking.
-        if not self.training:
-            feature_mask = 1.0
-        else:
-            # feature mask.
-            # on 0.25 of the frames, drop out the extra features [force a bottleneck.]
-            feature_mask_dropout_prob = 0.15
-            feature_unmasked_dim = 256  # hardcode dim for now, 1st 256 are non-masked.
-
-            feature_mask = torch.ones_like(src)  # S, N, E
-            # frame_mask is 0 with probability `feature_mask_dropout_prob`
-            # frame_mask shape: (S, N, 1)
-            frame_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
-
-            # for 10% of sequences, make the frame mask always-1, i.e. don't drop out any of
-            # the frames.  This is to make sure the model sometimes "sees" the same types of
-            # un-perturbed sequences that it will see in test time.
-            frame_mask = torch.logical_or(frame_mask,
-                                          torch.rand_like(src[:,:1,:1]) < 0.1)
-
-            feature_mask[..., feature_unmasked_dim:] *= frame_mask
-
-        output = output * feature_mask
-
         num_layers = len(self.layers)
         indexes = list(range(num_layers))
         if self.training:
@@ -417,6 +462,7 @@ class DownsampledConformerEncoder(nn.Module):
 
     def forward(self,
                 src: Tensor,
+                feature_mask: Union[Tensor, float] = 1.0,
                 mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
                 ) -> Tuple[Tensor, Tensor]:
@@ -424,6 +470,9 @@ class DownsampledConformerEncoder(nn.Module):
 
         Args:
             src: the sequence to the encoder (required).
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+               by at every layer.  feature_mask is expected to be already downsampled by
+               self.downsample_factor.
             mask: the mask for the src sequence (optional).  CAUTION: we need to downsample
                  this, if we are to support it.  Won't work correctly yet.
             src_key_padding_mask: the mask for the src keys per batch (optional).
@@ -446,7 +495,7 @@ class DownsampledConformerEncoder(nn.Module):
             src_key_padding_mask = src_key_padding_mask[::ds]
 
         src = self.encoder(
-            src, src_key_padding_mask=mask,
+            src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask,
         )
         src = self.upsample(src)
         # remove any extra frames that are not a multiple of downsample_factor
@@ -540,7 +589,10 @@ class SimpleCombiner(torch.nn.Module):
     """
     A very simple way of combining 2 vectors of 2 different dims, via a
    learned weighted combination in the shared part of the dim.
-
+    Args:
+       dim1: the dimension of the first input, e.g. 256
+       dim2: the dimension of the second input, e.g. 384.  Require dim2 >= dim1.
+          The output will have the same dimension as dim2.
     """
     def __init__(self,
                  dim1: int,
@@ -1381,7 +1433,7 @@ def _test_conformer_main():
     # Just make sure the forward pass runs.
 
     c = Conformer(
-        num_features=feature_dim, d_model=(64,96,128), nhead=(4,4)
+        num_features=feature_dim, d_model=(64,96), encoder_unmasked_dim=64, nhead=(4,4)
     )
     batch_size = 5
     seq_len = 20
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index fda28302f..cc1a21f64 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -118,6 +118,15 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "and the output dim of the encoder",
     )
 
+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=int,
+        default=256,
+        help="Unmasked dimension in the encoder; relates to a masking augmentation used during training. "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        "worse."
+    )
+
     parser.add_argument(
         "--conformer-subsampling-factor",
         type=int,
@@ -416,6 +425,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         subsampling_factor=params.subsampling_factor,
         conformer_subsampling_factor=params.conformer_subsampling_factor,
         d_model=to_int_list(params.encoder_dims),
+        encoder_unmasked_dim=params.encoder_unmasked_dim,
         nhead=to_int_list(params.nhead),
         feedforward_dim=to_int_list(params.feedforward_dims),
         num_encoder_layers=to_int_list(params.num_encoder_layers),
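
For readers who want to see the new augmentation in isolation, below is a minimal standalone sketch of the mask construction that the patch's get_feature_mask method performs. It is not part of the patch: the helper name make_feature_masks and the demo values are hypothetical; only the two probabilities (features dropped on ~15% of frames, ~10% of sequences left fully un-masked) and the repeat-by-subsampling-factor logic follow the diff above.

# Illustrative sketch only; mirrors the masking scheme introduced in this patch.
from typing import Tuple

import torch


def make_feature_masks(
    num_frames0: int,
    batch_size: int,
    d_model: Tuple[int, int] = (384, 384),   # (d_model0, d_model1)
    unmasked_dim: int = 256,                  # dims below this are never masked
    subsampling_factor: int = 4,
    feature_mask_dropout_prob: float = 0.15,
    sequence_mask_dropout_prob: float = 0.9,
) -> Tuple[torch.Tensor, torch.Tensor]:
    d_model0, d_model1 = d_model
    ds = subsampling_factor
    num_frames1 = (num_frames0 + ds - 1) // ds

    # frame_mask1 is 0 on roughly 15% of frames, except on roughly 10% of
    # sequences, which are left completely un-masked.
    frame_mask1 = torch.logical_or(
        torch.rand(num_frames1, batch_size, 1) > feature_mask_dropout_prob,
        torch.rand(1, batch_size, 1) > sequence_mask_dropout_prob,
    ).to(torch.float32)

    # Zero out only the dims above `unmasked_dim` for the downsampled encoder.
    feature_mask1 = torch.ones(num_frames1, batch_size, d_model1)
    feature_mask1[:, :, unmasked_dim:] *= frame_mask1

    # Repeat each downsampled frame's mask `ds` times so the full-rate mask
    # agrees with the downsampled one, then trim to the original frame count.
    frame_mask0 = (
        frame_mask1.unsqueeze(1)
        .expand(num_frames1, ds, batch_size, 1)
        .reshape(num_frames1 * ds, batch_size, 1)[:num_frames0]
    )
    feature_mask0 = torch.ones(num_frames0, batch_size, d_model0)
    feature_mask0[:, :, unmasked_dim:] *= frame_mask0

    return feature_mask0, feature_mask1


if __name__ == "__main__":
    m0, m1 = make_feature_masks(num_frames0=23, batch_size=2)
    print(m0.shape, m1.shape)  # torch.Size([23, 2, 384]) torch.Size([6, 2, 384])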