From 0685ac792d7920446cd35081df00bc9f90ddc8f0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 6 Oct 2022 12:36:42 +0800 Subject: [PATCH] Remove layer dropout and model-level warmup --- .../pruned_transducer_stateless7/conformer.py | 92 ++----------------- 1 file changed, 8 insertions(+), 84 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 308e2e13d..9dfd682a5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -203,7 +203,6 @@ class ConformerEncoderLayer(nn.Module): attn_scores_in: Optional[Tensor] = None, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - layerdrop_scale: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """ Pass the input through the encoder layer. @@ -216,8 +215,6 @@ class ConformerEncoderLayer(nn.Module): src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). batch_split: if not None, this layer will only be applied to - layerdrop_scale: an optional Tensor of broadcasting with `src` that will be used as a scale - on the change in the embeddings made by this layer. Shape: src: (S, N, E). @@ -286,76 +283,9 @@ class ConformerEncoder(nn.Module): ) self.num_layers = num_layers - self.to_layerdrop_scales = nn.Sequential( - ScaledLinear(num_layers, 256, initial_scale=0.5), - nn.ReLU(), - ScaledLinear(256, num_layers, initial_scale=0.01)) - num_channels = encoder_layer.norm_final.num_channels - def get_layerdrop_info(self) -> Tuple[Tensor, Optional[Tensor]]: - """ - Gets some random information that dictates layer dropout configuration. - Args: - batch_size: the number of sequences in the batch - Returns: - (mask, layerdrop_scales) - where: - layerdrop_mask is a CPU tensor of shape (num_layers,), - containing 1.0 for layers to be evaluated and 0.0 - for layers not to be evaluated. It has constraints: at least one of two - halves of each layer must be evaluated, and successive layers of the same half - pmust be evaluated. - - layerdrop_scales is a learned Tensor of shape (num_layers, batch_size, 1) of the form: - 1.0 + [learned matrix * (1.0 - layerdrop_scale)] - where layerdrop_scale is 1.0 for layers that computed, for this half, and - 0.0 for layers not computed. This is intended to learn that layers neighboring - layers that were not computed should get a higher scale to "make up" for the missing - computation. - The reason for the specific functional form is to constrain so that if everything - is computed (layerdrop_scale is all 1.0), this is constrained to be 1.0, to avoid - introducing redundant degrees of freedom. - """ - num_layers = self.num_layers - - # This ensures that if we are using multiple worker processes, they all use the same - # random numbers, so they will all take about the same amount of time to process - # the batch. - rng = random.Random(self.count) - self.count += 1 - - def get_random_mask(): - # 1.0 means don't drop the layer, 0.0 means drop the layer - mask = torch.ones(num_layers, device='cpu') - if not self.training: - return mask - r = rng.random() - if r < 0.1: - # drop zero layers, to make sure that sometimes we see the complete network. - return mask - final_layers_dropped = 0 - if r < 0.1 + 0.25: - # with prob 0.25: completely drop the last n layers. let n - # be a multiple of 3 (this is what we used to do with aux_layers). - final_layers_dropped = 3 * rng.randint(1, num_layers // 3) - mask[-final_layers_dropped:] = 0.0 - - layer_drop_prob = 0.075 - for i in range(num_layers - final_layers_dropped): - mask[i] = (rng.random() > layer_drop_prob) - if mask.sum() == 0.0: - mask[0] = 1.0 - return mask - - mask = get_random_mask() - device = self.to_layerdrop_scales[0].weight.device - layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device)) - if random.random() < 0.05: - logging.info(f"mask={mask}, layerdrop_scales = {layerdrop_scales.to('cpu')}") - return mask, layerdrop_scales - def forward( @@ -402,24 +332,18 @@ class ConformerEncoder(nn.Module): frame_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype) feature_mask[..., feature_unmasked_dim:] *= frame_mask - # deal with layer dropout. - layerdrop_mask, layerdrop_scales = self.get_layerdrop_info() - output = output * feature_mask for i, mod in enumerate(self.layers): - - if layerdrop_mask[i] != 0.0: - output, attn_scores = mod( - output, - pos_emb, - attn_scores, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - layerdrop_scale=layerdrop_scales[i], - ) - output = output * feature_mask + output, attn_scores = mod( + output, + pos_emb, + attn_scores, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + output = output * feature_mask return output