From 5fe8cb134f4d19b83d126a2ec292d73b91c66b1b Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 4 Oct 2022 22:19:44 +0800
Subject: [PATCH] Remove final combination; implement layer drop that drops the
 final layers.

---
 .../pruned_transducer_stateless7/conformer.py | 170 ++++++------------
 .../ASR/pruned_transducer_stateless7/train.py |   3 +-
 2 files changed, 61 insertions(+), 112 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 94bb0aa7b..2c74a23a6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -62,7 +62,6 @@ class Conformer(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.25,
         cnn_module_kernel: int = 31,
-        aux_layer_period: int = 3,
     ) -> None:
         super(Conformer, self).__init__()
 
@@ -90,7 +89,6 @@ class Conformer(EncoderInterface):
         self.encoder = ConformerEncoder(
             encoder_layer,
             num_encoder_layers,
-            aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
             layer_dropout=layer_dropout,
         )
 
@@ -210,8 +208,7 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-        layerdrop_mask: Optional[List[float]] = None,
-        layerdrop_scales: Optional[Tensor] = None,
+        layerdrop_scale: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
@@ -226,11 +223,7 @@ class ConformerEncoderLayer(nn.Module):
            warmup: controls selective bypass of layers; if < 1.0, we will
              bypass layers more frequently.
-           batch_split: if not None, this layer will only be applied to
-           layerdrop_mask: if None or [1.0, 1.0] then we do the computation as normal.  If
-              [1.0, 0.0] or [0.0, 1.0], we will only do this computation for the first or
-              second half of the batch respectively, and just copy the input for the other
-              half.
-           layerdrop_scales: an optional Tensor of shape (batch_size, 1) that will be used as a scale
+           layerdrop_scale: an optional Tensor, broadcastable with `src`, that will be used as a scale
              on the change in the embeddings made by this layer.
 
         Shape:
            src: (S, N, E).
            pos_emb: (N, 2*S-1, E)
            src_mask: (S, S).
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number
        """
-        if layerdrop_mask not in [ None, [1.0, 1.0] ]:
-            assert layerdrop_mask in [ [1.0, 0.0], [0.0, 1.0] ]
-            process_first_half = (layerdrop_mask == [1.0, 0.0])
-            batch_size = src.shape[1]
-            mid = batch_size // 2
-
-            if attn_scores_in is None:
-                seq_len = src.shape[0]
-                num_heads = self.self_attn.num_heads
-                attn_scores_in = torch.zeros(1, 1, 1, 1, device=src.device, dtype=src.dtype).expand(
-                    batch_size, seq_len, seq_len, num_heads)
-
-            attn_scores_a, attn_scores_b = attn_scores_in[:mid], attn_scores_in[mid:]
-            src_a, src_b = src[:, :mid], src[:, mid:]
-            key_padding_a, key_padding_b = src_key_padding_mask[:mid], src_key_padding_mask[mid:],
-            layerdrop_scales_a, layerdrop_scales_b = layerdrop_scales[:mid], layerdrop_scales[mid:]
-
-            if process_first_half:
-                src_a, attn_scores_a = self.forward(src_a, pos_emb, attn_scores_a, src_mask,
-                                                    key_padding_a, warmup,
-                                                    layerdrop_mask=None,
-                                                    layerdrop_scales=layerdrop_scales_a)
-            else:
-                src_b, attn_scores_b = self.forward(src_b, pos_emb, attn_scores_b, src_mask,
-                                                    key_padding_b, warmup,
-                                                    layerdrop_mask=None,
-                                                    layerdrop_scales=layerdrop_scales_b)
-
-            return torch.cat((src_a, src_b), dim=1), torch.cat((attn_scores_a, attn_scores_b), dim=0)
-
         src_orig = src
 
         warmup_scale = min(0.1 + warmup, 1.0)
         # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
         # completely bypass it.
@@ -301,11 +262,11 @@ class ConformerEncoderLayer(nn.Module):
 
         src = self.norm_final(self.balancer(src))
 
-        if alpha != 1.0 or layerdrop_scales is not None:
+        if alpha != 1.0 or layerdrop_scale is not None:
             # the if(self.training) part is to ensure we have a derivative for
             # self.scale_alpha.
             src_offset = src - src_orig
-            scale = alpha * (1.0 if layerdrop_scales is None else layerdrop_scales)
+            scale = alpha * (1.0 if layerdrop_scale is None else layerdrop_scale)
             src = src_orig + src_offset * scale
 
         return src, attn_scores_out
@@ -325,16 +286,18 @@ class ConformerEncoder(nn.Module):
         >>> pos_emb = torch.rand(32, 19, 512)
         >>> out = conformer_encoder(src, pos_emb)
     """
-
     def __init__(
         self,
         encoder_layer: nn.Module,
         num_layers: int,
-        aux_layers: List[int],
         layer_dropout: float = 0.25
     ) -> None:
         super().__init__()
         assert 0 < layer_dropout < 0.5
+        # `count` tracks how many times the forward function has been called
+        # since we initialized the model (it is not written to disk or read when
+        # we resume training).  It is used for random seeding for layer dropping.
+        self.count = 0
         self.layer_dropout = layer_dropout
 
         self.layers = nn.ModuleList(
@@ -342,31 +305,25 @@ class ConformerEncoder(nn.Module):
         )
         self.num_layers = num_layers
 
-        self.layerdrop_scale_mat = nn.Parameter(0.01 * torch.randn(num_layers, num_layers))
-        self.layerdrop_scale_offset = nn.Parameter(torch.ones(num_layers))
+        self.to_layerdrop_scales = nn.Sequential(
+            ScaledLinear(num_layers, 256, initial_scale=0.5),
+            nn.ReLU(),
+            ScaledLinear(256, num_layers, initial_scale=0.01))
 
-        assert num_layers - 1 not in aux_layers
-        self.aux_layers = set(aux_layers + [num_layers - 1])
         num_channels = encoder_layer.norm_final.num_channels
-        self.combiner = AttentionCombine(
-            num_channels=encoder_layer.d_model,
-            num_inputs=len(self.aux_layers),
-            random_prob=0.333,
-        )
 
-    def get_layerdrop_info(self,
-                           batch_size: int) -> Tuple[Tensor, Optional[Tensor]]:
+    def get_layerdrop_info(self) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Gets some random information that dictates layer dropout configuration.
 
-       Args:
-          batch_size: the number of sequences in the batch
        Returns:
-          (layerdrop_mask, layerdrop_scales)
+          (mask, layerdrop_scales)
        where:
-         layerdrop_mask is a CPU tensor of shape (num_layers, 2) where the 2 represents
-         two halves of the batch, containing 1.0 for positions to be evaluated and 0.0
-         for positions not to be evaluated.  It has constraints: at least one of two
-         halves of each layer must be evaluated, and successive layers of the same half
-         pmust be evaluated.
+         mask is a CPU tensor of shape (num_layers,), containing 1.0 for layers
+         to be evaluated and 0.0 for layers not to be evaluated.  It has the
+         constraint that at least one layer must be evaluated.
@@ -382,42 +339,39 @@ class ConformerEncoder(nn.Module):
        """
        num_layers = self.num_layers
 
-        layerdrop_mask = torch.ones(num_layers, 2, device='cpu')
+        # This ensures that if we are using multiple worker processes, they all use the same
+        # random numbers, so they will all take about the same amount of time to process
+        # the batch.
+        r = random.Random(self.count)
+        self.count += 1
 
-        if self.training and batch_size != 1:
-            halves_to_drop = int(2 * num_layers * self.layer_dropout)
-            for _ in range(halves_to_drop):
-                while True:
-                    r = random.randrange(0, 2 * num_layers)
-                    i = r // 2
-                    j = r % 2
-                    if layerdrop_mask[i, j - 1] == 0.0:
-                        # This position cannot be set to 0.0 because the other
-                        # half of the batch is already 0.0 (not computed).  This would lead to
-                        # one layer not having a gradient.
-                        continue
-                    if ((i > 0 and layerdrop_mask[i-1, j] == 0.0) or
-                        (i + 1 < num_layers and layerdrop_mask[i+1, j] == 0.0)):
-                        # This position cannot be set to False because the preceding
-                        # or following position for this same half of the batch is
-                        # already set to False
-                        continue
-                    layerdrop_mask[i, j] = 0.0
-                    break
+        def get_random_mask():
+            # 1.0 means don't drop the layer, 0.0 means drop the layer
+            mask = torch.ones(num_layers, device='cpu')
+            if not self.training:
+                # never drop layers in eval mode.
+                return mask
+            x = r.random()
+            if x < 0.1:
+                # drop zero layers, to make sure that sometimes we see the complete network.
+                return mask
+            final_layers_dropped = 0
+            if x < 0.1 + 0.25:
+                # with prob 0.25: completely drop the last n layers.  Let n
+                # be a multiple of 3 (this is what we used to do with aux_layers).
+                final_layers_dropped = 3 * r.randint(1, num_layers // 3)
+                mask[-final_layers_dropped:] = 0.0
 
-        # layerdrop_scales: currently shape is (2, num_layers)
-        device = self.layerdrop_scale_mat.device
-        layerdrop_scales_tmp = (self.layerdrop_scale_offset.unsqueeze(1) +
-                                torch.matmul(self.layerdrop_scale_mat,
-                                             1.0 - layerdrop_mask.to(device)))
+            # independently drop each of the remaining layers with a small probability.
+            layer_drop_prob = 0.075
+            for i in range(num_layers - final_layers_dropped):
+                mask[i] = (r.random() > layer_drop_prob)
 
-        layerdrop_scales = torch.empty(num_layers, batch_size, 1, device=device)
-        mid = batch_size // 2
+            if mask.sum() == 0.0:
+                # make sure at least one layer is evaluated.
+                mask[0] = 1.0
+            return mask
 
-        layerdrop_scales[:, :mid, 0] = layerdrop_scales_tmp[:,0:1] # shape: (num_layers, 1)
-        layerdrop_scales[:, mid:, 0] = layerdrop_scales_tmp[:,1:2] # shape: (num_layers, 1)
-
-        return layerdrop_mask, layerdrop_scales
+        mask = get_random_mask()
+        device = self.to_layerdrop_scales[0].weight.device
+        layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device))
+        return mask, layerdrop_scales
 
 
     def forward(
@@ -466,30 +420,24 @@ class ConformerEncoder(nn.Module):
             feature_mask[..., feature_unmasked_dim:] *= frame_mask
 
         # deal with layer dropout.
-        layerdrop_mask, layerdrop_scales = self.get_layerdrop_info(batch_size=src.shape[1])
+        layerdrop_mask, layerdrop_scales = self.get_layerdrop_info()
 
-        src = src * feature_mask
+        output = output * feature_mask
 
         for i, mod in enumerate(self.layers):
-            output, attn_scores = mod(
-                output,
-                pos_emb,
-                attn_scores,
-                src_mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-                layerdrop_mask=layerdrop_mask[i].tolist(),  # [ 1.0, 1.0 ], [0.0, 1.0] or [1.0, 0.0]
-                layerdrop_scales=layerdrop_scales[i],  # tensor of scales of shape (batch_size, 1)
-            )
-            output = output * feature_mask
-            if i in self.aux_layers:
-                outputs.append(output)
-
-        output = self.combiner(outputs)
-
-        output = output + 0.0 * attn_scores.sum()  # just ensure attn_scores is used in backprop
+            if layerdrop_mask[i] != 0.0:
+                # evaluate this layer; a dropped layer just passes `output` through unchanged.
+                output, attn_scores = mod(
+                    output,
+                    pos_emb,
+                    attn_scores,
+                    src_mask=mask,
+                    src_key_padding_mask=src_key_padding_mask,
+                    warmup=warmup,
+                    layerdrop_scale=layerdrop_scales[i],
+                )
+                output = output * feature_mask
 
         return output
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index e1090f6fd..176d8f207 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -924,7 +924,8 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[rank],
+                    find_unused_parameters=True)
 
     optimizer = ScaledAdam(model.parameters(),
                            lr=params.initial_lr,
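Note (illustrative sketch, not part of the patch): the layer-drop sampling done by
ConformerEncoder.get_layerdrop_info() above can be reproduced in isolation as below.
The helper name `sample_layerdrop_mask` and its arguments are hypothetical; the
probabilities (0.1 keep the full network, 0.25 drop the final layers in multiples of 3,
0.075 per-layer drop) are copied from the patch, and the seeded random.Random mirrors
how the patch keeps worker processes drawing the same mask for a given batch count.

    import random
    import torch

    def sample_layerdrop_mask(num_layers: int, batch_count: int,
                              training: bool = True) -> torch.Tensor:
        """Hypothetical stand-alone version of the mask sampling in
        ConformerEncoder.get_layerdrop_info().  Returns a float tensor of shape
        (num_layers,): 1.0 = evaluate the layer, 0.0 = skip it."""
        mask = torch.ones(num_layers)
        if not training:
            # Never drop layers at inference time.
            return mask
        # Seeding with the batch count makes every worker draw the same mask.
        rng = random.Random(batch_count)
        x = rng.random()
        if x < 0.1:
            # With prob 0.1, keep the complete network.
            return mask
        final_layers_dropped = 0
        if x < 0.1 + 0.25:
            # With prob 0.25, drop the last n layers, n a multiple of 3.
            final_layers_dropped = 3 * rng.randint(1, num_layers // 3)
            mask[-final_layers_dropped:] = 0.0
        # Independently drop each remaining layer with a small probability.
        layer_drop_prob = 0.075
        for i in range(num_layers - final_layers_dropped):
            if rng.random() < layer_drop_prob:
                mask[i] = 0.0
        if mask.sum() == 0.0:
            # Make sure at least one layer is evaluated.
            mask[0] = 1.0
        return mask

    if __name__ == "__main__":
        for count in range(3):
            print(count, sample_layerdrop_mask(num_layers=12, batch_count=count))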