diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 8b9cd9982..34b9d00f0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -215,6 +215,7 @@ class ConformerEncoderLayer(nn.Module): attn_scores_in: Optional[Tensor] = None, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + feature_mask: Optional[Tensor] = None, warmup: float = 1.0, ) -> Tensor: """ @@ -227,6 +228,8 @@ class ConformerEncoderLayer(nn.Module): passed from layer to layer. src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + feature_mask: a mask of shape (S, N, E), that randomly zeroes out + some of the features on each frame. warmup: controls selective bypass of of layers; if < 1.0, we will bypass layers more frequently. @@ -235,6 +238,7 @@ class ConformerEncoderLayer(nn.Module): pos_emb: (N, 2*S-1, E) src_mask: (S, S). src_key_padding_mask: (N, S). + feature_mask: (S, N, E) S is the source sequence length, N is the batch size, E is the feature number """ src_orig = src @@ -275,6 +279,9 @@ class ConformerEncoderLayer(nn.Module): if alpha != 1.0: src = alpha * src + (1 - alpha) * src_orig + if feature_mask is not None: + src = src * feature_mask + return src, attn_scores_out @@ -344,6 +351,20 @@ class ConformerEncoder(nn.Module): outputs = [] attn_scores = None + + if self.training: + # feature mask. + # on 0.25 of the frames, drop out the extra features [force a bottleneck.] + feature_mask_dropout_prob = 0.25 + feature_unmasked_dim = 256 # hardcode dim for now, 1st 256 are non-masked. + + feature_mask = torch.ones_like(src) # S, N, E + # is_masked_frame is 0 with probability `feature_mask_dropout_prob` + is_masked_frame = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype) + feature_mask[..., feature_unmasked_dim:] *= is_masked_frame + else: + feature_mask = None + for i, mod in enumerate(self.layers): output, attn_scores = mod( output, @@ -351,6 +372,7 @@ class ConformerEncoder(nn.Module): attn_scores, src_mask=mask, src_key_padding_mask=src_key_padding_mask, + feature_mask=feature_mask, warmup=warmup, ) if i in self.aux_layers: