diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index c4e1a597f..c987f3000 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -49,6 +49,8 @@ class Conformer(EncoderInterface):
         layer_dropout (float): layer-dropout rate.
         cnn_module_kernel (int): Kernel size of convolution module
         vgg_frontend (bool): whether to use vgg frontend.
+        warmup_batches (float): number of batches over which the layer-bypass
+            scale is gradually increased to its full value.
     """

     def __init__(
@@ -63,6 +65,7 @@ class Conformer(EncoderInterface):
         num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
         cnn_module_kernel: Tuple[int] = (31, 31),
+        warmup_batches: float = 3000,
     ) -> None:
         super(Conformer, self).__init__()

@@ -98,6 +101,7 @@ class Conformer(EncoderInterface):
             encoder_layer1,
             num_encoder_layers[0],
             dropout,
+            warmup_batches,
         )
         encoder_layer2 = ConformerEncoderLayer(
             d_model[1],
@@ -111,6 +115,7 @@ class Conformer(EncoderInterface):
                 encoder_layer2,
                 num_encoder_layers[1],
                 dropout,
+                warmup_batches,
             ),
             input_dim=d_model[0],
             output_dim=d_model[1],
@@ -362,13 +367,19 @@ class ConformerEncoder(nn.Module):
        original output, in case you need to bypass the RandomCombiner module.
     """
     def __init__(
-        self,
-        encoder_layer: nn.Module,
-        num_layers: int,
-        dropout: float
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        dropout: float,
+        warmup_batches: float,
     ) -> None:
         super().__init__()

+        # keep track of how many times forward() has been called, for purposes
+        # of 'warmup'
+        self.register_buffer('count', torch.tensor(0, dtype=torch.int64))
+        self.warmup_batches = warmup_batches
+
         self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout)
@@ -381,6 +392,18 @@ class ConformerEncoder(nn.Module):
         num_channels = encoder_layer.norm_final.num_channels
+
+    def get_warmup_value(self) -> float:
+        """
+        Returns a value that is 0 at the start of training and approaches 1.0
+        after a number of 'warmup' batches, specified in the constructor.
+        """
+        batch = self.count.item()
+        if self.training:
+            self.count += 1
+        return min(1.0, batch / self.warmup_batches)
+
+
     def forward(
         self,
         src: Tensor,
@@ -414,30 +437,48 @@ class ConformerEncoder(nn.Module):
         output = output * feature_mask

-        num_layers = len(self.layers)
-        indexes = list(range(num_layers))
-        if self.training:
-            num_to_swap = int((num_layers - 1) * 0.1)
-            for i in range(num_to_swap):
-                j = random.randint(0, num_layers - 2)
-                a,b = indexes[j], indexes[j+1]
-                indexes[j] = b
-                indexes[j+1] = a
+        outputs = [output]
+        # warmup starts at 0 at the beginning of training, reaches 1 after a
+        # few thousand minibatches, and then stays there.
+        warmup = self.get_warmup_value()

-        for i in indexes:
-            mod = self.layers[i]
+        def apply_bypass(prev_output: Tensor, output: Tensor,
+                         warmup: float,
+                         min_output_scale: float = 0.1,
+                         max_output_scale: float = 1.0) -> Tensor:
+            output_scale = max(warmup * max_output_scale,
+                               min_output_scale)
+            if output_scale == 1.0:
+                return output
+            else:
+                return output_scale * output + (1.0 - output_scale) * prev_output
+
+        for i, mod in enumerate(self.layers):
             output, attn_scores = mod(
-                output,
+                outputs[-1],
                 pos_emb,
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
-
+            # bypass this layer: the scale on `output` reaches a maximum of 0.5,
+            # which empirically seemed slightly better than 1.
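+            # concretely, output_scale = max(0.5 * warmup, 0.1): it stays at
+            # 0.1 until warmup exceeds 0.2, then ramps linearly up to 0.5.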
+            output = apply_bypass(outputs[-1], output,
+                                  warmup, 0.1, 0.5)
+            # also apply bypass across groups of two and of four layers.
+            if i > 0 and i % 2 == 0:
+                output = apply_bypass(outputs[-2], output,
+                                      warmup, 0.25, 1.0)
+            if i > 0 and i % 4 == 0:
+                output = apply_bypass(outputs[-4], output,
+                                      warmup, 0.25, 1.0)
             output = output * feature_mask
+            outputs.append(output)

-        return output
+        return outputs[-1]


 class DownsampledConformerEncoder(nn.Module):
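
Note, not part of the patch: a minimal standalone sketch of the warmup/bypass
schedule the diff implements. `apply_bypass` mirrors the helper added above; the
constants (0.1/0.5 for the per-layer bypass, 0.25/1.0 for the layer groups) and
warmup_batches = 3000 are the defaults used in the diff.

    import torch
    from torch import Tensor

    def apply_bypass(prev_output: Tensor, output: Tensor, warmup: float,
                     min_output_scale: float = 0.1,
                     max_output_scale: float = 1.0) -> Tensor:
        # Interpolate between a layer's input and its output; small `warmup`
        # means the layer is mostly bypassed.
        output_scale = max(warmup * max_output_scale, min_output_scale)
        if output_scale == 1.0:
            return output
        return output_scale * output + (1.0 - output_scale) * prev_output

    # At warmup=0.5 with max_output_scale=0.5, the layer output gets weight
    # 0.25 and its input gets weight 0.75:
    print(apply_bypass(torch.ones(3), torch.zeros(3), warmup=0.5,
                       max_output_scale=0.5))  # tensor([0.7500, 0.7500, 0.7500])

    # How the scales evolve over training with warmup_batches = 3000:
    for batch in (0, 600, 1500, 3000, 10000):
        warmup = min(1.0, batch / 3000)
        layer_scale = max(warmup * 0.5, 0.1)   # single-layer bypass
        group_scale = max(warmup * 1.0, 0.25)  # 2- and 4-layer group bypass
        print(f"batch={batch:>5}  warmup={warmup:.2f}  "
              f"layer_scale={layer_scale:.2f}  group_scale={group_scale:.2f}")

Because min_output_scale is strictly positive, no layer is ever gated out
entirely; the bypass only down-weights each layer's contribution early in
training and fades to a fixed residual-style mix once warmup reaches 1.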