From d34eafa6236251f9cb4f22a8c915897e668e7549 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 27 Sep 2022 15:47:58 +0800
Subject: [PATCH] Closer to working..

---
 .../pruned_transducer_stateless7/conformer.py | 223 ++++++++++++++----
 .../ASR/pruned_transducer_stateless7/train.py |  51 ++--
 2 files changed, 210 insertions(+), 64 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 6465d5a55..bc5d4a322 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -39,7 +39,8 @@ class Conformer(EncoderInterface):
     Args:
         num_features (int): Number of input features
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): attention dimension, also the output dimension
-        nhead (int): number of head
-        dim_feedforward (int): feedforward dimention
-        num_encoder_layers (int): number of encoder layers
+        d_model (Tuple[int]): (attention_dim1, attention_dim2, output_dim)
+        conformer_subsampling_factor (int): extra subsampling factor for the 2nd encoder stack
+        nhead (Tuple[int]): number of attention heads in each encoder stack
+        dim_feedforward (Tuple[int]): feedforward dimension in each encoder stack
+        num_encoder_layers (Tuple[int]): number of encoder layers in each encoder stack
@@ -53,13 +53,14 @@ class Conformer(EncoderInterface):
         self,
         num_features: int,
         subsampling_factor: int = 4,
-        d_model: int = 256,
-        nhead: int = 4,
-        dim_feedforward: int = 2048,
-        num_encoder_layers: int = 12,
+        conformer_subsampling_factor: int = 4,
+        d_model: Tuple[int] = (256, 384, 512),
+        nhead: Tuple[int] = (8, 8),
+        dim_feedforward: Tuple[int] = (1536, 2048),
+        num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        cnn_module_kernel: int = 31,
+        cnn_module_kernel: Tuple[int] = (31, 31),
         aux_layer_period: int = 3,
     ) -> None:
         super(Conformer, self).__init__()
@@ -74,23 +75,47 @@ class Conformer(EncoderInterface):
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_features -> d_model
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model[0],
+                                               dropout=dropout)
 
-        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-
-        encoder_layer = ConformerEncoderLayer(
-            d_model,
-            nhead,
-            dim_feedforward,
+        encoder_layer1 = ConformerEncoderLayer(
+            d_model[0],
+            nhead[0],
+            dim_feedforward[0],
             dropout,
             layer_dropout,
-            cnn_module_kernel,
+            cnn_module_kernel[0],
         )
-        self.encoder = ConformerEncoder(
-            encoder_layer,
-            num_encoder_layers,
-            aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
+        self.encoder1 = ConformerEncoder(
+            encoder_layer1,
+            num_encoder_layers[0],
+            aux_layers=list(range(0, num_encoder_layers[0] - 1, aux_layer_period)),
+            dropout=dropout,
         )
+        encoder_layer2 = ConformerEncoderLayer(
+            d_model[1],
+            nhead[1],
+            dim_feedforward[1],
+            dropout,
+            layer_dropout,
+            cnn_module_kernel[1],
+        )
+        self.encoder2 = DownsampledConformerEncoder(
+            ConformerEncoder(
+                encoder_layer2,
+                num_encoder_layers[1],
+                aux_layers=list(range(0, num_encoder_layers[1] - 1, aux_layer_period)),
+                dropout=dropout,
+            ),
+            input_dim=d_model[0],
+            module_dim=d_model[1],
+            output_dim=d_model[1],
+            downsample=conformer_subsampling_factor,
+        )
+
+        self.out_proj = ScaledLinear(
+            d_model[0] + d_model[1], d_model[2],
+            bias=False)
 
     def forward(
@@ -114,7 +139,7 @@ class Conformer(EncoderInterface):
             of frames in `embeddings` before padding.
""" x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) with warnings.catch_warnings(): @@ -124,12 +149,21 @@ class Conformer(EncoderInterface): assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder( - x, pos_emb, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + # x1: + x1, x_no_combine = self.encoder1( + x, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) where C == d_model[0] + + x2 = self.encoder1( + x1, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) where C == d_model[1] + + x = torch.cat((x1, x2), dim=2) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = self.out_proj(x) + return x, lengths @@ -288,8 +322,12 @@ class ConformerEncoder(nn.Module): >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = conformer_encoder(src, pos_emb) + >>> out = conformer_encoder(src) + + + Returns: (combined_output, output), + where `combined_output` has gone through the RandomCombiner module and `output` is just the + original output, in case you need to bypass the RandomCombiner module. """ def __init__( @@ -297,8 +335,13 @@ class ConformerEncoder(nn.Module): encoder_layer: nn.Module, num_layers: int, aux_layers: List[int], + dropout: float, ) -> None: super().__init__() + + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, + dropout) + self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) @@ -318,16 +361,14 @@ class ConformerEncoder(nn.Module): def forward( self, src: Tensor, - pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, warmup: float = 1.0, - ) -> Tensor: + ) -> Tuple[Tensor, Tensor]: r"""Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). @@ -338,7 +379,9 @@ class ConformerEncoder(nn.Module): src_key_padding_mask: (N, S). S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + Returns: (x, x_no_combine), both of shape (S, N, E) """ + pos_emb = self.encoder_pos(src) output = src outputs = [] @@ -356,11 +399,103 @@ class ConformerEncoder(nn.Module): if i in self.aux_layers: outputs.append(output) - output = self.combiner(outputs) + combined_output = self.combiner(outputs) - output = output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop + combined_output = combined_output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop - return output + return combined_output, output + + +class DownsampledConformerEncoder(nn.Module): + r""" + DownsampledConformerEncoder is a conformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output + so that the output has the same shape as the input. + """ + def __init__(self, + encoder: nn.Module, + input_dim: int, + module_dim: int, + output_dim: int, + downsample: int): + super(DownsampledConformerEncoder, self).__init__() + + self.downsample = downsample + + # note: we'll pad manually. 
+        self.downsample_conv = nn.Conv1d(
+            input_dim,
+            module_dim,
+            kernel_size=downsample,
+            stride=downsample,
+            padding=0)
+
+        self.encoder = encoder
+
+        self.upsample = nn.ConvTranspose1d(
+            module_dim,
+            output_dim,
+            kernel_size=downsample,
+            stride=downsample,
+            padding=0)
+
+    def forward(self,
+                src: Tensor,
+                mask: Optional[Tensor] = None,
+                src_key_padding_mask: Optional[Tensor] = None,
+                warmup: float = 1.0,
+                ) -> Tensor:
+        r"""Downsample, go through encoder, upsample.
+
+        Args:
+            src: the sequence to the encoder (required).
+            mask: the mask for the src sequence (optional).  CAUTION: we downsample
+                this by simple subsampling, which is only approximate; it won't
+                work correctly yet.
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number.
+
+        Returns: output of shape (S, N, E)
+        """
+        (seq_len, batch_size, embedding_dim) = src.shape
+        x = src
+        ds = self.downsample
+        d_seq_len = (seq_len + ds - 1) // ds
+        if seq_len != d_seq_len * ds:
+            # right-pad x along the time axis, replicating the final frame.
+            pad = d_seq_len * ds - seq_len
+            x = torch.cat((x, x[-1:].expand(pad, batch_size, embedding_dim)),
+                          dim=0)
+
+        if mask is not None:
+            mask = mask[::ds, ::ds]
+        if src_key_padding_mask is not None:
+            src_key_padding_mask = src_key_padding_mask[:, ::ds]
+
+        x = x.permute(1, 2, 0)  # (#batch, channels, time)
+        x = self.downsample_conv(x)
+        x = x.permute(2, 0, 1)  # (time, batch, channels)
+
+        x, _x_no_combine = self.encoder(
+            x, mask=mask, src_key_padding_mask=src_key_padding_mask, warmup=warmup
+        )  # (T, N, C)
+
+        x = x.permute(1, 2, 0)  # (#batch, channels, time)
+        x = self.upsample(x)
+        x = x.permute(2, 0, 1)  # (time, batch, channels)
+
+        new_seq_len = x.shape[0]
+        assert new_seq_len >= seq_len
+        if new_seq_len > seq_len:
+            x = x[:seq_len]
+
+        return x
 
 
 class RelPositionalEncoding(torch.nn.Module):
@@ -379,7 +514,7 @@ class RelPositionalEncoding(torch.nn.Module):
     def __init__(
         self, d_model: int, dropout_rate: float, max_len: int = 5000
     ) -> None:
-        """Construct an PositionalEncoding object."""
+        """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
         self.dropout = torch.nn.Dropout(dropout_rate)
@@ -391,7 +526,7 @@ class RelPositionalEncoding(torch.nn.Module):
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
+            if self.pe.size(1) >= x.size(0) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -401,9 +536,9 @@ class RelPositionalEncoding(torch.nn.Module):
         # Suppose `i` means the position of the query vector and `j` means the
         # position of the key vector.  We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x.size(1), self.d_model)
-        pe_negative = torch.zeros(x.size(1), self.d_model)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        pe_positive = torch.zeros(x.size(0), self.d_model)
+        pe_negative = torch.zeros(x.size(0), self.d_model)
+        position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
             torch.arange(0, self.d_model, 2, dtype=torch.float32)
             * -(math.log(10000.0) / self.d_model)
@@ ... @@ class Conv2dSubsampling(nn.Module):
         layer1_channels: int = 8,
         layer2_channels: int = 32,
         layer3_channels: int = 128,
+        dropout: float = 0.1,
     ) -> None:
         """
         Args:
@@ -998,6 +1134,7 @@ class Conv2dSubsampling(nn.Module):
         )
         out_height = (((in_channels - 1) // 2 - 1) // 2)
         self.out = ScaledLinear(out_height * layer3_channels, out_channels)
+        self.dropout = nn.Dropout(dropout)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1017,6 +1154,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+        x = self.dropout(x)
         return x
 
 class RandomCombine(nn.Module):
@@ -1251,14 +1389,13 @@ def _test_random_combine_main():
 
 def _test_conformer_main():
     feature_dim = 50
-    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     feature_dim = 50
     # Just make sure the forward pass runs.
     c = Conformer(
-        num_features=feature_dim, d_model=128, nhead=4
+        num_features=feature_dim, d_model=(64, 96, 128), nhead=(4, 4)
     )
     batch_size = 5
     seq_len = 20
@@ -1271,8 +1408,6 @@ def _test_conformer_main():
     f  # to remove flake8 warnings
 
-
-
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index aa345dd84..42274ce5c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -91,30 +91,38 @@ LRSchedulerType = Union[
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
-        type=int,
-        default=24,
-        help="Number of conformer encoder layers..",
+        type=str,
+        default="12,12",
+        help="Number of conformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
-        "--dim-feedforward",
-        type=int,
-        default=1536,
-        help="Feedforward dimension of the conformer encoder layer.",
+        "--feedforward-dims",
+        type=str,
+        default="1536,1536",
+        help="Feedforward dimension of the conformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
         "--nhead",
-        type=int,
-        default=8,
-        help="Number of attention heads in the conformer encoder layer.",
+        type=str,
+        default="8,8",
+        help="Number of attention heads in the conformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
-        "--encoder-dim",
+        "--encoder-dims",
+        type=str,
+        default="320,512,512",
+        help="Attention dimensions of the 2 stacks of conformer encoder layers, "
+        "followed by the output dim of the encoder; comma separated.",
+    )
+
+    parser.add_argument(
+        "--conformer-subsampling-factor",
         type=int,
-        default=384,
-        help="Attention dimension in the conformer encoder layer.",
+        default=4,
+        help="Subsampling factor for the 2nd stack of encoder layers.",
     )
 
     parser.add_argument(
@@ -401,13 +409,16 @@ def get_params() -> AttributeDict:
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
+    def to_int_list(s: str):
+        return list(map(int, s.split(',')))
+
     encoder = Conformer(
         num_features=params.feature_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.encoder_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
+        subsampling_factor=params.subsampling_factor,
+        conformer_subsampling_factor=params.conformer_subsampling_factor,
+        d_model=to_int_list(params.encoder_dims),
+        nhead=to_int_list(params.nhead),
+        dim_feedforward=to_int_list(params.feedforward_dims),
+        num_encoder_layers=to_int_list(params.num_encoder_layers),
     )
     return encoder
@@ -424,7 +435,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=params.encoder_dim,
+        encoder_dim=int(params.encoder_dims.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -441,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=params.encoder_dim,
+        encoder_dim=int(params.encoder_dims.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
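
The sketches below are illustrative only and are not part of the patch. The first one walks through the arithmetic behind DownsampledConformerEncoder: the input is right-padded along time to a multiple of the downsampling factor, downsampled with a strided Conv1d, encoded, upsampled with a ConvTranspose1d of the same kernel and stride, and trimmed back to the original length. ToyEncoder is a hypothetical stand-in for the wrapped ConformerEncoder.

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    # Identity stand-in for ConformerEncoder; returns (combined, raw),
    # mirroring the two-output interface the patch introduces.
    def forward(self, x, mask=None, src_key_padding_mask=None, warmup=1.0):
        return x, x


seq_len, batch, in_dim, mod_dim, out_dim, ds = 23, 2, 256, 384, 384, 4
x = torch.randn(seq_len, batch, in_dim)  # (T, N, C), like the encoder input

down = nn.Conv1d(in_dim, mod_dim, kernel_size=ds, stride=ds)
up = nn.ConvTranspose1d(mod_dim, out_dim, kernel_size=ds, stride=ds)
enc = ToyEncoder()

# Right-pad T to a multiple of ds by replicating the last frame.
d_seq_len = (seq_len + ds - 1) // ds  # ceil(23 / 4) == 6
pad = d_seq_len * ds - seq_len        # 6 * 4 - 23 == 1
x_pad = torch.cat((x, x[-1:].expand(pad, batch, in_dim)), dim=0)

y = down(x_pad.permute(1, 2, 0)).permute(2, 0, 1)  # (6, N, mod_dim)
y, _ = enc(y)
y = up(y.permute(1, 2, 0)).permute(2, 0, 1)        # (24, N, out_dim)
y = y[:seq_len]                                    # trim back to (23, N, out_dim)
assert y.shape == (seq_len, batch, out_dim)

With kernel_size == stride == downsample, the Conv1d maps a padded length of d_seq_len * ds to exactly d_seq_len frames, and the ConvTranspose1d maps it back to d_seq_len * ds, so the final trim restores the original seq_len.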
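The second sketch shows how Conformer.forward combines the two stacks: the full-rate output x1 (d_model[0] channels) and the re-upsampled output x2 (d_model[1] channels) are concatenated on the channel axis and projected to d_model[2]. Plain nn.Linear stands in for icefall's ScaledLinear, and the dimensions are illustrative.

import torch
import torch.nn as nn

d_model = (256, 384, 512)           # (attention_dim1, attention_dim2, output_dim)
T, N = 100, 8
x1 = torch.randn(T, N, d_model[0])  # encoder1 output, full frame rate
x2 = torch.randn(T, N, d_model[1])  # encoder2 output, upsampled back to length T
out_proj = nn.Linear(d_model[0] + d_model[1], d_model[2], bias=False)
out = out_proj(torch.cat((x1, x2), dim=2))
assert out.shape == (T, N, d_model[2])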
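Finally, the comma-separated flags added to train.py reach the Conformer constructor via the to_int_list helper. A small round trip using the new defaults; the consistency assertions at the end are assumptions about intended usage, not checks the patch performs.

def to_int_list(s: str):
    return list(map(int, s.split(',')))

num_encoder_layers = to_int_list("12,12")    # --num-encoder-layers
feedforward_dims = to_int_list("1536,1536")  # --feedforward-dims
nhead = to_int_list("8,8")                   # --nhead
encoder_dims = to_int_list("320,512,512")    # --encoder-dims

# Assumed invariants: two attention dims plus the encoder output dim,
# and one entry per encoder stack for the per-stack settings.
assert len(encoder_dims) == 3
assert len(num_encoder_layers) == len(feedforward_dims) == len(nhead) == 2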