From a9f950a1f7c2dabfe4a561d17d8205e171241262 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 3 Oct 2022 22:49:32 +0800
Subject: [PATCH] Make the scaling factors more global and the randomness of
 dropout more random

---
 .../pruned_transducer_stateless7/conformer.py | 138 ++++++++++++------
 1 file changed, 94 insertions(+), 44 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 30e387ce8..472d8c921 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -60,7 +60,7 @@ class Conformer(EncoderInterface):
         dim_feedforward: int = 2048,
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
-        layer_dropout: float = 0.333,
+        layer_dropout: float = 0.25,
         cnn_module_kernel: int = 31,
         aux_layer_period: int = 3,
     ) -> None:
@@ -153,7 +153,6 @@ class ConformerEncoderLayer(nn.Module):
     >>> pos_emb = torch.rand(32, 19, 512)
     >>> out = encoder_layer(src, pos_emb)
     """
-
     def __init__(
        self,
        d_model: int,
@@ -193,12 +192,8 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
-
         self.norm_final = BasicNorm(d_model)
 
-        # scale_alpha relates to a scale that can help work around layerdrop during training.
-        self.scale_alpha = torch.nn.Parameter(torch.tensor(0.0))
-
         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
             d_model,
             channel_dim=-1,
@@ -207,7 +202,6 @@ class ConformerEncoderLayer(nn.Module):
             max_var_per_eig=0.2,
         )
 
-
     def forward(
         self,
         src: Tensor,
@@ -216,8 +210,8 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-        batch_split: Optional[bool] = None,
-        layerdrop_indicator: float = 1.0,
+        layerdrop_mask: Optional[List[float]] = None,
+        layerdrop_scales: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
@@ -232,12 +226,12 @@ class ConformerEncoderLayer(nn.Module):
           warmup: controls selective bypass of of layers; if < 1.0, we will
             bypass layers more frequently.
-          batch_split: if not None, this layer will only be applied to
-            part of the batch. if True we apply it to the first half of the batch
-            elements, otherwise to the second half.
-          layerdrop_indicator: a float. It is supposed to be 1.0 if nothing is dropped out,
-            and 0.0 if something is dropped out. You don't have to set this directly,
-            it is set internally if you provide the batch_split option as non-None.
+          layerdrop_mask: if None or [1.0, 1.0] then we do the computation as normal.  If
+            [1.0, 0.0] or [0.0, 1.0], we will only do this computation for the first or
+            second half of the batch respectively, and just copy the input for the other
+            half.
+          layerdrop_scales: an optional Tensor of shape (batch_size, 1) that will be used as a scale
+            on the change in the embeddings made by this layer.
 
         Shape:
             src: (S, N, E).
             pos_emb: (N, 2*S-1, E)
             attn_scores_in: (num_heads, batch_size, seq_len, seq_len)
             src_mask: (S, S).
@@ -246,8 +240,9 @@ class ConformerEncoderLayer(nn.Module):
             src_key_padding_mask: (N, S).
         S is the source sequence length, N is the batch size, E is the feature number
         """
-        if batch_split is not None:
-            process_first_half = batch_split
+        if layerdrop_mask not in [ None, [1.0, 1.0] ]:
+            assert layerdrop_mask in [ [1.0, 0.0], [0.0, 1.0] ]
+            process_first_half = (layerdrop_mask == [1.0, 0.0])
             batch_size = src.shape[1]
             mid = batch_size // 2
 
@@ -257,19 +252,21 @@ class ConformerEncoderLayer(nn.Module):
 
             attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device,
                                          dtype=src.dtype).expand(
                                              batch_size, seq_len, seq_len, num_heads)
-
             attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
             src_a, src_b = src[:, :mid], src[:, mid:]
             key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:],
+            layerdrop_scales_a, layerdrop_scales_b = layerdrop_scales[:mid], layerdrop_scales[mid:]
 
             if process_first_half:
                 src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
-                                                    key_padding_a, warmup, batch_split=None,
-                                                    layerdrop_indicator=0.0)
+                                                    key_padding_a, warmup,
+                                                    layerdrop_mask=None,
+                                                    layerdrop_scales=layerdrop_scales_a)
             else:
                 src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
-                                                    key_padding_b, warmup, batch_split=None,
-                                                    layerdrop_indicator=0.0)
+                                                    key_padding_b, warmup,
+                                                    layerdrop_mask=None,
+                                                    layerdrop_scales=layerdrop_scales_b)
 
             return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
@@ -304,11 +301,9 @@ class ConformerEncoderLayer(nn.Module):
 
         src = self.norm_final(self.balancer(src))
 
-        if alpha != 1.0 or layerdrop_indicator != 1.0 or self.training:
-            # the if(self.training) part is to ensure we have a derivative for
-            # self.scale_alpha.
+        if alpha != 1.0 or layerdrop_scales is not None:
             src_offset = src - src_orig
-            scale = alpha * (1.0 + self.scale_alpha * (1.0 - layerdrop_indicator))
+            scale = alpha * (1.0 if layerdrop_scales is None else layerdrop_scales)
             src = src_orig + src_offset * scale
 
         return src, attn_scores_out
@@ -334,7 +331,7 @@
         encoder_layer: nn.Module,
         num_layers: int,
         aux_layers: List[int],
-        layer_dropout: float = 0.333
+        layer_dropout: float = 0.25
     ) -> None:
         super().__init__()
         assert 0 < layer_dropout < 0.5
@@ -345,6 +342,9 @@
         )
         self.num_layers = num_layers
 
+        self.layerdrop_scale_mat = nn.Parameter(0.01 * torch.randn(num_layers, num_layers))
+
+        assert num_layers - 1 not in aux_layers
         self.aux_layers = set(aux_layers + [num_layers - 1])
@@ -355,6 +355,72 @@
             random_prob=0.333,
         )
 
+    def get_layerdrop_info(self,
+                           batch_size: int) -> Tuple[Tensor, Optional[Tensor]]:
+        """
+        Gets some random information that dictates the layer-dropout configuration.
+        Args:
+          batch_size: the number of sequences in the batch
+        Returns:
+          (layerdrop_mask, layerdrop_scales)
+        where:
+          layerdrop_mask is a CPU tensor of shape (num_layers, 2), where the 2 represents
+            the two halves of the batch; it contains 1.0 for positions to be evaluated and
+            0.0 for positions not to be evaluated.  It has two constraints: at least one of
+            the two halves of each layer must be evaluated, and no two successive layers
+            may both be dropped for the same half.
+
+          layerdrop_scales is a learned Tensor of shape (num_layers, batch_size, 1) of the form:
+               1.0 + [learned matrix * (1.0 - layerdrop_mask)]
+            where layerdrop_mask is 1.0 for layers that are computed, for this half, and
+            0.0 for layers that are not computed.
+            This is intended to learn that layers adjacent to a layer that was not computed
+            should get a higher scale, to "make up" for the missing computation.
+            The reason for this specific functional form is that if everything is computed
+            (layerdrop_mask is all 1.0), the scale is constrained to be exactly 1.0, to
+            avoid introducing redundant degrees of freedom.
+        """
+        num_layers = self.num_layers
+
+        layerdrop_mask = torch.ones(num_layers, 2, device='cpu')
+
+        if not self.training or batch_size == 1:
+            return layerdrop_mask, None
+
+        halves_to_drop = int(2 * num_layers * self.layer_dropout)
+        for _ in range(halves_to_drop):
+            while True:
+                r = random.randrange(0, 2 * num_layers)
+                i = r // 2
+                j = r % 2
+                if layerdrop_mask[i, j - 1] == 0.0:
+                    # (Index j - 1 wraps around to the other column.)  This position
+                    # cannot be set to 0.0 because the other half of the batch is
+                    # already 0.0 (not computed); that would lead to one layer not
+                    # having a gradient.
+                    continue
+                if ((i > 0 and layerdrop_mask[i-1, j] == 0.0) or
+                    (i + 1 < num_layers and layerdrop_mask[i+1, j] == 0.0)):
+                    # This position cannot be set to 0.0 because the preceding
+                    # or following position for this same half of the batch is
+                    # already set to 0.0.
+                    continue
+                layerdrop_mask[i, j] = 0.0
+                break
+
+        # layerdrop_scales_tmp: shape is (num_layers, 2)
+        device = self.layerdrop_scale_mat.device
+        layerdrop_scales_tmp = 1.0 + torch.matmul(self.layerdrop_scale_mat,
+                                                  1.0 - layerdrop_mask.to(device))
+
+        layerdrop_scales = torch.empty(num_layers, batch_size, 1, device=device)
+        mid = batch_size // 2
+
+        layerdrop_scales[:, :mid, 0] = layerdrop_scales_tmp[:, 0:1]  # shape: (num_layers, 1)
+        layerdrop_scales[:, mid:, 0] = layerdrop_scales_tmp[:, 1:2]  # shape: (num_layers, 1)
+
+        return layerdrop_mask, layerdrop_scales
+
+
     def forward(
         self,
         src: Tensor,
@@ -401,30 +467,13 @@
             feature_mask[..., feature_unmasked_dim:] *= frame_mask
 
         # deal with layer dropout.
-        batch_size = src.shape[1]
-        if not self.training or batch_size == 1:
-            dropped_layer_pairs = set()  # empty set.
-        else:
-            num_layer_pairs = len(self.layers) // 2
-            layer_pairs = list(range(num_layer_pairs))
-            random.shuffle(layer_pairs)
-            # the * 2 is because we only drop out one layer from each pair:
-            # half for one half of the batch and the other half for the other.
-            num_dropped_pairs = int(self.layer_dropout * 2 * num_layer_pairs)
-            dropped_layer_pairs = set(layer_pairs[:num_dropped_pairs])
+        layerdrop_mask, layerdrop_scales = self.get_layerdrop_info(batch_size=src.shape[1])
 
-        rand_bool = (random.random() < 0.5)
-
         src = src * feature_mask
 
         for i, mod in enumerate(self.layers):
-            if i // 2 not in dropped_layer_pairs:
-                batch_split = None  # no layer dropout
-            else:
-                batch_split = rand_bool if i % 2 == 0 else not rand_bool
-
             output, attn_scores = mod(
                 output,
                 pos_emb,
@@ -432,7 +481,8 @@
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
-                batch_split=batch_split,
+                layerdrop_mask=layerdrop_mask[i].tolist(),  # [1.0, 1.0], [0.0, 1.0] or [1.0, 0.0]
+                layerdrop_scales=None if layerdrop_scales is None else layerdrop_scales[i],  # (batch_size, 1) scales, or None
             )
             output = output * feature_mask
             if i in self.aux_layers:
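
The following is a minimal standalone sketch, not part of the patch, of the constrained sampling that get_layerdrop_info performs: it drops random (layer, batch-half) positions until the requested count is reached, while never dropping both halves of one layer and never dropping the same half in two successive layers. The function name sample_layerdrop_mask is illustrative only.

import random

import torch


def sample_layerdrop_mask(num_layers: int, layer_dropout: float) -> torch.Tensor:
    """Return a (num_layers, 2) tensor of 1.0/0.0 drop decisions, one column per
    half of the batch, under the same constraints as the patch."""
    assert 0 < layer_dropout < 0.5  # keeps the rejection loop feasible, as in the patch
    layerdrop_mask = torch.ones(num_layers, 2)
    halves_to_drop = int(2 * num_layers * layer_dropout)
    for _ in range(halves_to_drop):
        while True:
            r = random.randrange(0, 2 * num_layers)
            i, j = r // 2, r % 2
            if layerdrop_mask[i, j - 1] == 0.0:
                # The other half of this layer is already dropped (j - 1 wraps to the
                # other column); dropping this half too would skip the layer entirely.
                continue
            if ((i > 0 and layerdrop_mask[i - 1, j] == 0.0) or
                    (i + 1 < num_layers and layerdrop_mask[i + 1, j] == 0.0)):
                # The previous or next layer is already dropped for this same half.
                continue
            layerdrop_mask[i, j] = 0.0
            break
    return layerdrop_mask


print(sample_layerdrop_mask(num_layers=12, layer_dropout=0.25))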
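
The learned scales are then an affine function of that mask. The sketch below, again standalone and with made-up sizes, only shows the shape bookkeeping: a (num_layers, num_layers) matrix maps each batch half's drop pattern to a per-layer scale for that half, and the two columns are broadcast over the two halves of the batch. When nothing is dropped the product term is zero, so every scale is exactly 1.0.

import torch

num_layers, batch_size = 12, 8   # illustrative sizes
mid = batch_size // 2

layerdrop_scale_mat = 0.01 * torch.randn(num_layers, num_layers)  # an nn.Parameter in the patch

mask = torch.ones(num_layers, 2)  # one column per batch half
mask[3, 0] = 0.0                  # say layer 3 is dropped for the first half
mask[7, 1] = 0.0                  # and layer 7 for the second half

# 1.0 + scale_mat @ (1 - mask): identically 1.0 when nothing is dropped, so the
# learned matrix only redistributes "make-up" scale around dropped layers.
scales_tmp = 1.0 + layerdrop_scale_mat @ (1.0 - mask)  # (num_layers, 2)

scales = torch.empty(num_layers, batch_size, 1)
scales[:, :mid, 0] = scales_tmp[:, 0:1]   # first batch half gets column 0
scales[:, mid:, 0] = scales_tmp[:, 1:2]   # second batch half gets column 1
print(scales.shape)  # torch.Size([12, 8, 1])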
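
Finally, each layer applies its slice of these scales only to the change it makes to its input, so a scale of exactly 1.0 reproduces the layer's normal behaviour. This is a hedged sketch of that residual scaling under the same assumptions; apply_layer_scale is an illustrative name, and alpha stands in for the existing warmup-based bypass factor in the encoder layer.

from typing import Optional

import torch


def apply_layer_scale(src_orig: torch.Tensor,
                      src: torch.Tensor,
                      alpha: float,
                      layerdrop_scales: Optional[torch.Tensor]) -> torch.Tensor:
    """Scale only the change this layer made to its input."""
    if alpha != 1.0 or layerdrop_scales is not None:
        scale = alpha * (1.0 if layerdrop_scales is None else layerdrop_scales)
        src = src_orig + (src - src_orig) * scale
    return src


# (seq_len, batch, dim) embeddings before and after a layer, and a (batch, 1) scale.
src_orig = torch.randn(100, 8, 512)
src = src_orig + torch.randn(100, 8, 512)
out = apply_layer_scale(src_orig, src, alpha=1.0,
                        layerdrop_scales=torch.full((8, 1), 1.1))
print(out.shape)  # torch.Size([100, 8, 512])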