diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 4cb244769..8108ce7f7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -60,7 +60,7 @@ class Conformer(EncoderInterface):
         dim_feedforward: int = 2048,
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
-        layer_dropout: float = 0.075,
+        layer_dropout: float = 0.333,
         cnn_module_kernel: int = 31,
         aux_layer_period: int = 3,
     ) -> None:
@@ -85,13 +85,13 @@ class Conformer(EncoderInterface):
             nhead,
             dim_feedforward,
             dropout,
-            layer_dropout,
             cnn_module_kernel,
         )
         self.encoder = ConformerEncoder(
             encoder_layer,
             num_encoder_layers,
             aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
+            layer_dropout=layer_dropout,
         )
@@ -160,13 +160,10 @@ class ConformerEncoderLayer(nn.Module):
         nhead: int,
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
-        layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.layer_dropout = layer_dropout
-
         self.d_model = d_model
 
         self.self_attn = RelPositionMultiheadAttention(
@@ -215,53 +212,80 @@ class ConformerEncoderLayer(nn.Module):
     def forward(
         self,
         src: Tensor,
-        feature_mask: Union[Tensor, float],
         pos_emb: Tensor,
         attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tensor:
+        batch_split: Optional[bool] = None,
+        layerdrop_indicator: float = 1.0,
+    ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
 
         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
-            attn_scores_in: something with the dimension fo attention weights (bsz * num_heads, len, len) that is
+            attn_scores_in: a tensor with the dimensions of attention weights (bsz, len, len, num_heads) that is
              passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
-            feature_mask: a mask of shape (S, N, E), that randomly zeroes out
-              some of the features on each frame.
             warmup: controls selective bypass of of layers; if < 1.0, we will
              bypass layers more frequently.
+            batch_split: if not None, this layer will be applied to only part
+              of the batch.  If True we apply it to the first half of the batch
+              elements, otherwise to the second half.
+            layerdrop_indicator: a float that is 1.0 if nothing is dropped out
+              and 0.0 if something is dropped out.  You do not have to set this
+              directly; it is set internally whenever batch_split is non-None.
+
         Shape:
             src: (S, N, E).
-            feature_mask: float, or (S, N, 1)
             pos_emb: (N, 2*S-1, E)
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number
        """
+        if batch_split is not None:
+            process_first_half = batch_split
+            batch_size = src.shape[1]
+            mid = batch_size // 2
+
+            if attn_scores_in is None:
+                seq_len = src.shape[0]
+                num_heads = self.self_attn.num_heads
+                attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device, dtype=src.dtype).expand(
+                    batch_size, seq_len, seq_len, num_heads)
+
+            attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
+            src_a, src_b = src[:, :mid], src[:, mid:]
+            key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:]
+
+            if process_first_half:
+                src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
+                                                    key_padding_a, warmup, batch_split=None,
+                                                    layerdrop_indicator=0.0)
+            else:
+                src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
+                                                    key_padding_b, warmup, batch_split=None,
+                                                    layerdrop_indicator=0.0)
+
+            return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
+
        src_orig = src
 
        warmup_scale = min(0.1 + warmup, 1.0)
        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
        # completely bypass it.
-        if self.training:
-            alpha = (
-                warmup_scale
-                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
-                else 0.1
-            )
-        else:
-            alpha = 1.0
+        alpha = warmup_scale if self.training else 1.0
 
        # macaron style feed forward module
        src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
-                                                    feature_mask)
+                                                    layerdrop_indicator)
 
        # multi-headed self-attention module
        src_att, _, attn_scores_out = self.self_attn(
@@ -271,18 +295,18 @@ class ConformerEncoderLayer(nn.Module):
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
-        src = src + self.self_attn_scale(src_att, feature_mask)
+        src = src + self.self_attn_scale(src_att, layerdrop_indicator)
 
        # convolution module
        src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
-                                    feature_mask)
+                                    layerdrop_indicator)
 
        # feed forward module
        src = src + self.feed_forward_scale(self.feed_forward(src),
-                                            feature_mask)
+                                            layerdrop_indicator)
 
-        src = self.final_scale(src, feature_mask)
+        src = self.final_scale(src, layerdrop_indicator)
 
        src = self.norm_final(self.balancer(src))
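The batch_split recursion above boils down to: slice the batch in half, run the layer on one half, pass the other half through unchanged, and re-concatenate. A minimal standalone sketch of that pattern, not taken from the patch: the helper name apply_to_half is hypothetical, and a plain function stands in for the real encoder layer.

    import torch

    def apply_to_half(layer_fn, x: torch.Tensor, first_half: bool) -> torch.Tensor:
        # x has shape (S, N, E); split along the batch dimension N.
        mid = x.shape[1] // 2
        a, b = x[:, :mid], x[:, mid:]
        if first_half:
            a = layer_fn(a)  # layer runs on the first half only
        else:
            b = layer_fn(b)  # layer runs on the second half only
        return torch.cat((a, b), dim=1)

    x = torch.randn(5, 4, 8)
    y = apply_to_half(lambda t: t * 2.0, x, first_half=True)
    assert torch.equal(y[:, 2:], x[:, 2:])      # second half bypassed the "layer"
    assert torch.equal(y[:, :2], x[:, :2] * 2)  # first half was processed

Splitting the batch this way, rather than coin-flipping per layer as the deleted alpha logic did, means a "dropped" layer still trains on half of the examples in every step.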
@@ -312,8 +336,12 @@ class ConformerEncoder(nn.Module):
         encoder_layer: nn.Module,
         num_layers: int,
         aux_layers: List[int],
+        layer_dropout: float = 0.333,
     ) -> None:
         super().__init__()
+        assert 0 < layer_dropout < 0.5
+        self.layer_dropout = layer_dropout
+
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
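The new assert 0 < layer_dropout < 0.5 reflects the pair-based scheduling in the next hunk: each selected pair drops only one of its two layers per batch element, so the per-element drop rate can never exceed one half. A quick check of the arithmetic, assuming the num_dropped_pairs formula from ConformerEncoder.forward() below:

    num_layers = 12                       # the default num_encoder_layers
    num_layer_pairs = num_layers // 2     # 6 pairs
    for layer_dropout in (0.075, 0.333, 0.499):
        num_dropped_pairs = int(layer_dropout * 2 * num_layer_pairs)
        assert num_dropped_pairs <= num_layer_pairs
        # each element skips one layer per dropped pair, so the realized
        # per-element drop rate is num_dropped_pairs / num_layers:
        # 0.075 -> 0/12, 0.333 -> 3/12, 0.499 -> 5/12
        print(layer_dropout, num_dropped_pairs)

Note the int() truncation: the new default of 0.333 yields 3 dropped pairs (a realized rate of 0.25), whereas a layer_dropout of exactly 1/3 would yield 4.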
@@ -359,34 +387,56 @@ class ConformerEncoder(nn.Module):
 
        attn_scores = None
 
-        if self.training:
+        # deal with feature masking.
+        if not self.training:
+            feature_mask = 1.0
+        else:
            # feature mask.
            # on 0.25 of the frames, drop out the extra features [force a bottleneck.]
            feature_mask_dropout_prob = 0.15
            feature_unmasked_dim = 256  # hardcode dim for now, 1st 256 are non-masked.
-            full_feature_mask = torch.ones_like(src)  # S, N, E
-            # feature_mask is 0 with probability `feature_mask_dropout_prob`
-            # feature_mask shape: (S, N, 1)
-            feature_mask = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
-            full_feature_mask[..., feature_unmasked_dim:] *= feature_mask
-        else:
-            feature_mask = 1.0
-            full_feature_mask = 1.0
+            feature_mask = torch.ones_like(src)  # S, N, E
+            # frame_mask is 0 with probability `feature_mask_dropout_prob`
+            # frame_mask shape: (S, N, 1)
+            frame_mask = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
+            feature_mask[..., feature_unmasked_dim:] *= frame_mask
 
-        src = src * full_feature_mask
+        # deal with layer dropout.
+        batch_size = src.shape[1]
+        if not self.training or batch_size == 1:
+            dropped_layer_pairs = set()  # empty set.
+        else:
+            num_layer_pairs = len(self.layers) // 2
+            layer_pairs = list(range(num_layer_pairs))
+            random.shuffle(layer_pairs)
+            # the * 2 is because we drop out only one layer from each chosen pair:
+            # one layer for one half of the batch, the other layer for the other half.
+            num_dropped_pairs = int(self.layer_dropout * 2 * num_layer_pairs)
+            dropped_layer_pairs = set(layer_pairs[:num_dropped_pairs])
+
+        rand_bool = (random.random() < 0.5)
+
+        src = src * feature_mask
 
        for i, mod in enumerate(self.layers):
+            if i // 2 not in dropped_layer_pairs:
+                batch_split = None  # no layer dropout
+            else:
+                batch_split = rand_bool if i % 2 == 0 else not rand_bool
+
            output, attn_scores = mod(
                output,
-                feature_mask,
                pos_emb,
                attn_scores,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                warmup=warmup,
+                batch_split=batch_split,
            )
-            output = output * full_feature_mask
+            output = output * feature_mask
 
            if i in self.aux_layers:
                outputs.append(output)
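To see what the pair scheduling produces, here is a hedged standalone sketch; the function name schedule_batch_splits is hypothetical, and its body mirrors the logic added to ConformerEncoder.forward() above:

    import random

    def schedule_batch_splits(num_layers: int, layer_dropout: float):
        # One entry per layer: None means "apply to the whole batch"; a bool
        # says which half of the batch the layer IS applied to.
        num_pairs = num_layers // 2
        pairs = list(range(num_pairs))
        random.shuffle(pairs)
        num_dropped = int(layer_dropout * 2 * num_pairs)
        dropped = set(pairs[:num_dropped])
        rand_bool = random.random() < 0.5
        return [
            None if i // 2 not in dropped
            else (rand_bool if i % 2 == 0 else not rand_bool)
            for i in range(num_layers)
        ]

    print(schedule_batch_splits(12, 0.333))
    # Three of the six pairs are split; within a split pair the two layers get
    # opposite booleans, so every batch element goes through exactly one layer
    # of the pair, i.e. exactly 9 of the 12 layers.

Design note: because the two halves of the batch drop disjoint layers, every layer receives gradient on every training step, unlike the deleted torch.rand(()) test, which downweighted a layer's contribution for the entire batch at once.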