From d34eafa6236251f9cb4f22a8c915897e668e7549 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 27 Sep 2022 15:47:58 +0800 Subject: [PATCH 1/8] Closer to working.. --- .../pruned_transducer_stateless7/conformer.py | 223 ++++++++++++++---- .../ASR/pruned_transducer_stateless7/train.py | 51 ++-- 2 files changed, 210 insertions(+), 64 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 6465d5a55..bc5d4a322 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -39,7 +39,7 @@ class Conformer(EncoderInterface): Args: num_features (int): Number of input features subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension, also the output dimension + d_model (int): (attention_dim1, attention_dim2, output_dim) nhead (int): number of head dim_feedforward (int): feedforward dimention num_encoder_layers (int): number of encoder layers @@ -53,13 +53,14 @@ class Conformer(EncoderInterface): self, num_features: int, subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, + conformer_subsampling_factor: int = 4, + d_model: Tuple[int] = (256, 384, 512), + nhead: Tuple[int] = (8, 8), + dim_feedforward: Tuple[int] = (1536, 2048), + num_encoder_layers: Tuple[int] = (12, 12), dropout: float = 0.1, layer_dropout: float = 0.075, - cnn_module_kernel: int = 31, + cnn_module_kernel: Tuple[int] = (31, 31), aux_layer_period: int = 3, ) -> None: super(Conformer, self).__init__() @@ -74,23 +75,47 @@ class Conformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_embed = Conv2dSubsampling(num_features, d_model[0], + dropout=dropout) - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = ConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, + encoder_layer1 = ConformerEncoderLayer( + d_model[0], + nhead[0], + dim_feedforward[0], dropout, layer_dropout, - cnn_module_kernel, + cnn_module_kernel[0], ) - self.encoder = ConformerEncoder( - encoder_layer, - num_encoder_layers, - aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)), + self.encoder1 = ConformerEncoder( + encoder_layer1, + num_encoder_layers[0], + aux_layers=list(range(0, num_encoder_layers[0] - 1, aux_layer_period)), + dropout=dropout ) + encoder_layer2 = ConformerEncoderLayer( + d_model[1], + nhead[1], + dim_feedforward[1], + dropout, + layer_dropout, + cnn_module_kernel[1], + ) + self.encoder2 = DownsampledConformerEncoder( + ConformerEncoder( + encoder_layer2, + num_encoder_layers[1], + aux_layers=list(range(0, num_encoder_layers[1] - 1, aux_layer_period)), + dropout=dropout + ), + input_dim=d_model[0], + module_dim=d_model[1], + output_dim=d_model[1], + downsample=conformer_subsampling_factor, + ) + + self.out_proj = ScaledLinear( + d_model[0] + d_model[1], d_model[2], + bias=False) def forward( @@ -114,7 +139,7 @@ class Conformer(EncoderInterface): of frames in `embeddings` before padding. 
""" x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) with warnings.catch_warnings(): @@ -124,12 +149,21 @@ class Conformer(EncoderInterface): assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder( - x, pos_emb, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + # x1: + x1, x_no_combine = self.encoder1( + x, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) where C == d_model[0] + + x2 = self.encoder1( + x1, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) where C == d_model[1] + + x = torch.cat((x1, x2), dim=2) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = self.out_proj(x) + return x, lengths @@ -288,8 +322,12 @@ class ConformerEncoder(nn.Module): >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = conformer_encoder(src, pos_emb) + >>> out = conformer_encoder(src) + + + Returns: (combined_output, output), + where `combined_output` has gone through the RandomCombiner module and `output` is just the + original output, in case you need to bypass the RandomCombiner module. """ def __init__( @@ -297,8 +335,13 @@ class ConformerEncoder(nn.Module): encoder_layer: nn.Module, num_layers: int, aux_layers: List[int], + dropout: float, ) -> None: super().__init__() + + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, + dropout) + self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) @@ -318,16 +361,14 @@ class ConformerEncoder(nn.Module): def forward( self, src: Tensor, - pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, warmup: float = 1.0, - ) -> Tensor: + ) -> Tuple[Tensor, Tensor]: r"""Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). @@ -338,7 +379,9 @@ class ConformerEncoder(nn.Module): src_key_padding_mask: (N, S). S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + Returns: (x, x_no_combine), both of shape (S, N, E) """ + pos_emb = self.encoder_pos(src) output = src outputs = [] @@ -356,11 +399,103 @@ class ConformerEncoder(nn.Module): if i in self.aux_layers: outputs.append(output) - output = self.combiner(outputs) + combined_output = self.combiner(outputs) - output = output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop + combined_output = combined_output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop - return output + return combined_output, output + + +class DownsampledConformerEncoder(nn.Module): + r""" + DownsampledConformerEncoder is a conformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output + so that the output has the same shape as the input. + """ + def __init__(self, + encoder: nn.Module, + input_dim: int, + module_dim: int, + output_dim: int, + downsample: int): + super(DownsampledConformerEncoder, self).__init__() + + self.downsample = downsample + + # note: we'll pad manually. 
+ self.downsample = nn.Conv1d( + input_dim, + module_dim, + kernel_size=downsample, + stride=downsample, + padding=0) + + self.encoder = encoder + + self.upsample = nn.ConvTranspose1d( + module_dim, + output_dim, + kernel_size=downsample, + stride=downsample, + padding=0) + + def forward(self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + ) -> Tuple[Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). CAUTION: we need to downsample + this, if we are to support it. Won't work correctly yet. + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: (x, x_no_combine), both of shape (S, N, E) + """ + (seq_len, batch_size, embedding_dim) = x.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + x_orig = x + if seq_len != d_seq_len * ds: + # right-pad x + pad = seq_len - d_seq_len * ds + x = torch.nn.functional.pad(x, + (0, pad, 0, 0, 0, 0), + mode='replicate') + + if mask is not None: + mask = mask[::ds,::ds] + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask[::ds] + + x = x.permute(1, 2, 0) # (#batch, channels, time). + x = self.downsample(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + + x, _x_no_combine = self.encoder( + x, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) + + x = x.permute(1, 2, 0) # (#batch, channels, time). + x = self.upsample(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + new_seq_len = x.shape[0] + assert new_seq_len >= seq_len + if new_seq_len > seq_len: + x = x[:seq_len] + + return x class RelPositionalEncoding(torch.nn.Module): @@ -379,7 +514,7 @@ class RelPositionalEncoding(torch.nn.Module): def __init__( self, d_model: int, dropout_rate: float, max_len: int = 5000 ) -> None: - """Construct an PositionalEncoding object.""" + """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model self.dropout = torch.nn.Dropout(dropout_rate) @@ -391,7 +526,7 @@ class RelPositionalEncoding(torch.nn.Module): if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device if self.pe.dtype != x.dtype or str(self.pe.device) != str( x.device @@ -401,9 +536,9 @@ class RelPositionalEncoding(torch.nn.Module): # Suppose `i` means to the position of query vecotr and `j` means the # position of key vector. 
We use position relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i None: """ Args: @@ -998,6 +1134,7 @@ class Conv2dSubsampling(nn.Module): ) out_height = (((in_channels - 1) // 2 - 1) // 2) self.out = ScaledLinear(out_height * layer3_channels, out_channels) + self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -1017,6 +1154,7 @@ class Conv2dSubsampling(nn.Module): b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.dropout(x) return x class RandomCombine(nn.Module): @@ -1251,14 +1389,13 @@ def _test_random_combine_main(): def _test_conformer_main(): feature_dim = 50 - c = Conformer(num_features=feature_dim, d_model=128, nhead=4) batch_size = 5 seq_len = 20 feature_dim = 50 # Just make sure the forward pass runs. c = Conformer( - num_features=feature_dim, d_model=128, nhead=4 + num_features=feature_dim, d_model=(64,96,128), nhead=(4,4) ) batch_size = 5 seq_len = 20 @@ -1271,8 +1408,6 @@ def _test_conformer_main(): f # to remove flake8 warnings - - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index aa345dd84..42274ce5c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -91,30 +91,38 @@ LRSchedulerType = Union[ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", - type=int, - default=24, - help="Number of conformer encoder layers..", + type=str, + default="12,12", + help="Number of conformer encoder layers, comma separated.", ) parser.add_argument( - "--dim-feedforward", - type=int, - default=1536, - help="Feedforward dimension of the conformer encoder layer.", + "--feedforward-dims", + type=str, + default="1536,1536", + help="Feedforward dimension of the conformer encoder layers, comma separated.", ) parser.add_argument( "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", + type=str, + default="8,8", + help="Number of attention heads in the conformer encoder layers.", ) parser.add_argument( - "--encoder-dim", + "--encoder-dims", + type=str, + default="320,512,512", + help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, " + "and the output dim of the encoder", + ) + + parser.add_argument( + "--conformer-subsampling-factor", type=int, - default=384, - help="Attention dimension in the conformer encoder layer.", + default=4, + help="Subsampling factor for 2nd stack of encoder layers.", ) parser.add_argument( @@ -401,13 +409,16 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer + def to_int_list(s: str): + return list(map(int, s.split(','))) encoder = Conformer( num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, + subsampling_factor=params.subsamplng_factor, + conformer_subsampling_factor=params.conformer_subsamplng_factor, + d_model=to_int_list(params.encoder_dims), + nhead=to_int_list(params.nhead), + 
feedforward_dims=to_int_list(params.feedforward_dims), + num_encoder_layers=to_int_list(params.num_encoder_layers), ) return encoder @@ -424,7 +435,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=params.encoder_dim, + encoder_dim=int(params.encoder_dims.split(',')[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -441,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=params.encoder_dim, + encoder_dim=int(params.encoder_dims.split(',')[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, From 01af88c2f6b364f33677ac3ff02c5931d9d85c7b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 27 Sep 2022 16:09:30 +0800 Subject: [PATCH 2/8] Various fixes --- .../pruned_transducer_stateless7/conformer.py | 71 ++++++++++--------- .../ASR/pruned_transducer_stateless7/train.py | 6 +- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index bc5d4a322..1855e06ae 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -56,7 +56,7 @@ class Conformer(EncoderInterface): conformer_subsampling_factor: int = 4, d_model: Tuple[int] = (256, 384, 512), nhead: Tuple[int] = (8, 8), - dim_feedforward: Tuple[int] = (1536, 2048), + feedforward_dim: Tuple[int] = (1536, 2048), num_encoder_layers: Tuple[int] = (12, 12), dropout: float = 0.1, layer_dropout: float = 0.075, @@ -81,7 +81,7 @@ class Conformer(EncoderInterface): encoder_layer1 = ConformerEncoderLayer( d_model[0], nhead[0], - dim_feedforward[0], + feedforward_dim[0], dropout, layer_dropout, cnn_module_kernel[0], @@ -95,7 +95,7 @@ class Conformer(EncoderInterface): encoder_layer2 = ConformerEncoderLayer( d_model[1], nhead[1], - dim_feedforward[1], + feedforward_dim[1], dropout, layer_dropout, cnn_module_kernel[1], @@ -150,12 +150,12 @@ class Conformer(EncoderInterface): mask = make_pad_mask(lengths) # x1: - x1, x_no_combine = self.encoder1( + x1, x1_no_combine = self.encoder1( x, src_key_padding_mask=mask, warmup=warmup ) # (T, N, C) where C == d_model[0] - x2 = self.encoder1( - x1, src_key_padding_mask=mask, warmup=warmup + x2 = self.encoder2( + x1_no_combine, src_key_padding_mask=mask, warmup=warmup ) # (T, N, C) where C == d_model[1] x = torch.cat((x1, x2), dim=2) @@ -175,7 +175,7 @@ class ConformerEncoderLayer(nn.Module): Args: d_model: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). + feedforward_dim: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. 
@@ -190,7 +190,7 @@ class ConformerEncoderLayer(nn.Module): self, d_model: int, nhead: int, - dim_feedforward: int = 2048, + feedforward_dim: int = 2048, dropout: float = 0.1, layer_dropout: float = 0.075, cnn_module_kernel: int = 31, @@ -206,22 +206,22 @@ class ConformerEncoderLayer(nn.Module): ) self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - ActivationBalancer(dim_feedforward, + nn.Linear(d_model, feedforward_dim), + ActivationBalancer(feedforward_dim, channel_dim=-1, max_abs=10.0), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, + ScaledLinear(feedforward_dim, d_model, initial_scale=0.1), ) self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - ActivationBalancer(dim_feedforward, + nn.Linear(d_model, feedforward_dim), + ActivationBalancer(feedforward_dim, channel_dim=-1, max_abs=10.0), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, + ScaledLinear(feedforward_dim, d_model, initial_scale=0.1), ) @@ -420,7 +420,7 @@ class DownsampledConformerEncoder(nn.Module): downsample: int): super(DownsampledConformerEncoder, self).__init__() - self.downsample = downsample + self.downsample_factor = downsample # note: we'll pad manually. self.downsample = nn.Conv1d( @@ -459,43 +459,44 @@ class DownsampledConformerEncoder(nn.Module): src_key_padding_mask: (N, S). S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - Returns: (x, x_no_combine), both of shape (S, N, E) + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) """ - (seq_len, batch_size, embedding_dim) = x.shape - ds = self.downsample + (seq_len, batch_size, embedding_dim) = src.shape + ds = self.downsample_factor d_seq_len = (seq_len + ds - 1) // ds - x_orig = x + src_orig = src if seq_len != d_seq_len * ds: - # right-pad x - pad = seq_len - d_seq_len * ds - x = torch.nn.functional.pad(x, - (0, pad, 0, 0, 0, 0), - mode='replicate') + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds if mask is not None: mask = mask[::ds,::ds] if src_key_padding_mask is not None: src_key_padding_mask = src_key_padding_mask[::ds] - x = x.permute(1, 2, 0) # (#batch, channels, time). - x = self.downsample(x) - x = x.permute(2, 0, 1) # (time, batch, channels) + src = src.permute(1, 2, 0) # (#batch, channels, time). + src = self.downsample(src) + src = src.permute(2, 0, 1) # (time, batch, channels) - x, _x_no_combine = self.encoder( - x, src_key_padding_mask=mask, warmup=warmup + src, _src_no_combine = self.encoder( + src, src_key_padding_mask=mask, warmup=warmup ) # (T, N, C) - x = x.permute(1, 2, 0) # (#batch, channels, time). - x = self.upsample(x) - x = x.permute(2, 0, 1) # (time, batch, channels) + src = src.permute(1, 2, 0) # (#batch, channels, time). 
+ src = self.upsample(src) + src = src.permute(2, 0, 1) # (time, batch, channels) - new_seq_len = x.shape[0] + new_seq_len = src.shape[0] assert new_seq_len >= seq_len if new_seq_len > seq_len: - x = x[:seq_len] + src = src[:seq_len] - return x + return src class RelPositionalEncoding(torch.nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 42274ce5c..4d87d3e73 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -413,11 +413,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: return list(map(int, s.split(','))) encoder = Conformer( num_features=params.feature_dim, - subsampling_factor=params.subsamplng_factor, - conformer_subsampling_factor=params.conformer_subsamplng_factor, + subsampling_factor=params.subsampling_factor, + conformer_subsampling_factor=params.conformer_subsampling_factor, d_model=to_int_list(params.encoder_dims), nhead=to_int_list(params.nhead), - feedforward_dims=to_int_list(params.feedforward_dims), + feedforward_dim=to_int_list(params.feedforward_dims), num_encoder_layers=to_int_list(params.num_encoder_layers), ) return encoder From 10a3061025a4a2db536a97ab27d6178613ec586f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 28 Sep 2022 13:49:11 +0800 Subject: [PATCH 3/8] Simplify downsampling and upsampling --- .../pruned_transducer_stateless7/conformer.py | 163 +++++++++++++----- 1 file changed, 121 insertions(+), 42 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 1855e06ae..312ffe0e6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -108,14 +108,12 @@ class Conformer(EncoderInterface): dropout=dropout ), input_dim=d_model[0], - module_dim=d_model[1], output_dim=d_model[1], downsample=conformer_subsampling_factor, ) - self.out_proj = ScaledLinear( - d_model[0] + d_model[1], d_model[2], - bias=False) + self.out_combiner = SimpleCombiner(d_model[0], + d_model[1]) def forward( @@ -158,12 +156,10 @@ class Conformer(EncoderInterface): x1_no_combine, src_key_padding_mask=mask, warmup=warmup ) # (T, N, C) where C == d_model[1] - x = torch.cat((x1, x2), dim=2) + x = self.out_combiner(x1, x2) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - x = self.out_proj(x) - return x, lengths @@ -415,29 +411,14 @@ class DownsampledConformerEncoder(nn.Module): def __init__(self, encoder: nn.Module, input_dim: int, - module_dim: int, output_dim: int, downsample: int): super(DownsampledConformerEncoder, self).__init__() - self.downsample_factor = downsample - - # note: we'll pad manually. 
- self.downsample = nn.Conv1d( - input_dim, - module_dim, - kernel_size=downsample, - stride=downsample, - padding=0) - + self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder + self.upsample = SimpleUpsample(output_dim, downsample) - self.upsample = nn.ConvTranspose1d( - module_dim, - output_dim, - kernel_size=downsample, - stride=downsample, - padding=0) def forward(self, src: Tensor, @@ -462,10 +443,56 @@ class DownsampledConformerEncoder(nn.Module): Returns: output of shape (S, N, F) where F is the number of output features (output_dim to constructor) """ - (seq_len, batch_size, embedding_dim) = src.shape + + src_orig = src + src = self.downsample(src) ds = self.downsample_factor + if mask is not None: + mask = mask[::ds,::ds] + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask[::ds] + + src, _src_no_combine = self.encoder( + src, src_key_padding_mask=mask, warmup=warmup + ) + src = self.upsample(src) + + return src + + +class AttentionDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + def __init__(self, + in_channels: int, + out_channels: int, + downsample: int): + """ + Require out_channels > in_channels. + """ + super(AttentionDownsample, self).__init__() + assert out_channels > in_channels + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + + # fill in the extra dimensions with a projection of the input + self.extra_proj = nn.Linear(in_channels * downsample, + out_channels - in_channels, + bias=False) + self.downsample = downsample + + def forward(self, + src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, out_channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample d_seq_len = (seq_len + ds - 1) // ds src_orig = src + # Pad to an exact multiple of self.downsample if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len @@ -473,31 +500,83 @@ class DownsampledConformerEncoder(nn.Module): src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds - if mask is not None: - mask = mask[::ds,::ds] - if src_key_padding_mask is not None: - src_key_padding_mask = src_key_padding_mask[::ds] + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + weights = scores.softmax(dim=1) - src = src.permute(1, 2, 0) # (#batch, channels, time). - src = self.downsample(src) - src = src.permute(2, 0, 1) # (time, batch, channels) + # ans1 is the first `in_channels` channels of the output + ans1 = (src * weights).sum(dim=1) + src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) + ans2 = self.extra_proj(src) - src, _src_no_combine = self.encoder( - src, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + ans = torch.cat((ans1, ans2), dim=2) + return ans - src = src.permute(1, 2, 0) # (#batch, channels, time). - src = self.upsample(src) - src = src.permute(2, 0, 1) # (time, batch, channels) - new_seq_len = src.shape[0] - assert new_seq_len >= seq_len - if new_seq_len > seq_len: - src = src[:seq_len] +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. 
+ """ + def __init__(self, + num_channels: int, + upsample: int): + super(SimpleUpsample, self).__init__() + self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + def forward(self, + src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.bias.shape[0] + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src + self.bias.unsqueeze(1) + src = src.reshape(seq_len * upsample, batch_size, num_channels) return src +class SimpleCombiner(torch.nn.Module): + """ + A very simple way of combining 2 vectors of 2 different dims, via a + learned weighted combination in the shared part of the dim. + + """ + def __init__(self, + dim1: int, + dim2: int): + super(SimpleCombiner, self).__init__() + assert dim2 > dim1 + self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01) + self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01) + + + def forward(self, + src1: Tensor, + src2: Tensor) -> Tensor: + """ + src1: (*, dim1) + src2: (*, dim2) + + Returns: a tensor of shape (*, dim2) + """ + assert src1.shape[:-1] == src2.shape[:-1] + dim1 = src1.shape[-1] + dim2 = src2.shape[-1] + + weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True) + weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True) + weight = (weight1 + weight2).sigmoid() + + src2_part1 = src2[...,:dim1] + part1 = src1 * weight + src2_part1 * (1.0 - weight) + part2 = src2[...,dim1:] + return torch.cat((part1, part2), dim=-1) + + class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. From 1005ff35bad1881396d50590bd643c5c26448d5b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 28 Sep 2022 13:57:26 +0800 Subject: [PATCH 4/8] Fix w.r.t. 
uneven upsampling --- egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 312ffe0e6..79a6ebefb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -456,6 +456,8 @@ class DownsampledConformerEncoder(nn.Module): src, src_key_padding_mask=mask, warmup=warmup ) src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[:src_orig.shape[0]] return src From df795912ed44e37087e892e691a1cb4e312090b2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 28 Sep 2022 20:56:40 +0800 Subject: [PATCH 5/8] Try to reproduce baseline but with current code with 2 encoder stacks, as a baseline --- .../pruned_transducer_stateless7/conformer.py | 17 ++++++++++------- .../ASR/pruned_transducer_stateless7/train.py | 4 ++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 79a6ebefb..9facae5ce 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -478,9 +478,12 @@ class AttentionDownsample(torch.nn.Module): self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) # fill in the extra dimensions with a projection of the input - self.extra_proj = nn.Linear(in_channels * downsample, - out_channels - in_channels, - bias=False) + if out_channels > in_channels: + self.extra_proj = nn.Linear(in_channels * downsample, + out_channels - in_channels, + bias=False) + else: + self.extra_proj = None self.downsample = downsample def forward(self, @@ -509,10 +512,10 @@ class AttentionDownsample(torch.nn.Module): # ans1 is the first `in_channels` channels of the output ans1 = (src * weights).sum(dim=1) src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) - ans2 = self.extra_proj(src) - - ans = torch.cat((ans1, ans2), dim=2) + if self.extra_proj is not None: + ans2 = self.extra_proj(src) + ans = torch.cat((ans1, ans2), dim=2) return ans @@ -551,7 +554,7 @@ class SimpleCombiner(torch.nn.Module): dim1: int, dim2: int): super(SimpleCombiner, self).__init__() - assert dim2 > dim1 + assert dim2 >= dim1 self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01) self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 4d87d3e73..7d9ec647e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -113,7 +113,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-dims", type=str, - default="320,512,512", + default="384,384", help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, " "and the output dim of the encoder", ) @@ -121,7 +121,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--conformer-subsampling-factor", type=int, - default=4, + default=1, help="Subsampling factor for 2nd stack of encoder layers.", ) From e5666628bd655901402517c7cd163d4b258bc60e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 28 Sep 
2022 20:58:34 +0800 Subject: [PATCH 6/8] Bug fix --- egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 9facae5ce..bd0e625f0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -474,7 +474,6 @@ class AttentionDownsample(torch.nn.Module): Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - assert out_channels > in_channels self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) # fill in the extra dimensions with a projection of the input From 14a2603ada9b98420d8716baba12360e1d9c1ec0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 28 Sep 2022 20:59:24 +0800 Subject: [PATCH 7/8] Bug fix --- egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index bd0e625f0..7bda58669 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -509,12 +509,12 @@ class AttentionDownsample(torch.nn.Module): weights = scores.softmax(dim=1) # ans1 is the first `in_channels` channels of the output - ans1 = (src * weights).sum(dim=1) + ans = (src * weights).sum(dim=1) src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) if self.extra_proj is not None: ans2 = self.extra_proj(src) - ans = torch.cat((ans1, ans2), dim=2) + ans = torch.cat((ans, ans2), dim=2) return ans From d6ef1bec5f6d8693939b0f88807b2b28488b62d2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 28 Sep 2022 21:10:13 +0800 Subject: [PATCH 8/8] Change subsamplling factor from 1 to 2 --- egs/librispeech/ASR/pruned_transducer_stateless7/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 7d9ec647e..f35b0e7af 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -121,7 +121,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--conformer-subsampling-factor", type=int, - default=1, + default=2, help="Subsampling factor for 2nd stack of encoder layers.", )
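
The three helper modules introduced in PATCH 3 can be exercised on their own to check the shapes they produce. The snippet below is a minimal sketch, assuming the classes are importable from the recipe's conformer.py (run it from egs/librispeech/ASR/pruned_transducer_stateless7/) and assuming the state of the code after the fixes in PATCHES 5-7; it is an illustration of the intended shape behaviour, not an excerpt of the committed code.

    import torch
    from conformer import AttentionDownsample, SimpleUpsample, SimpleCombiner

    seq_len, batch, d1, d2, ds = 17, 3, 256, 384, 4
    x = torch.randn(seq_len, batch, d1)                  # (S, N, d_model[0])

    # Weighted-sum downsampling over groups of `ds` frames; the last group is
    # padded by repeating the final frame, so the output length rounds up.
    down = AttentionDownsample(d1, d2, downsample=ds)
    y = down(x)
    assert y.shape == ((seq_len + ds - 1) // ds, batch, d2)

    # Upsampling repeats each frame `ds` times and adds a per-position bias;
    # DownsampledConformerEncoder then trims back to the original length (PATCH 4).
    up = SimpleUpsample(d2, upsample=ds)
    z = up(y)[:seq_len]
    assert z.shape == (seq_len, batch, d2)

    # SimpleCombiner mixes the narrow input into the first d1 channels of the
    # wide one via a data-dependent scalar gate and passes the rest through.
    comb = SimpleCombiner(d1, d2)
    out = comb(x, z)
    assert out.shape == (seq_len, batch, d2)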
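
For an end-to-end check of the two-stack encoder, _test_conformer_main() in conformer.py already constructs the model with tuple-valued options; the following is essentially that test with the output shapes spelled out. Constructor defaults (feedforward_dim, cnn_module_kernel, conformer_subsampling_factor) are left as in the patch.

    import torch
    from conformer import Conformer

    c = Conformer(num_features=50, d_model=(64, 96, 128), nhead=(4, 4))
    feats = torch.randn(5, 20, 50)                        # (N, T, num_features)
    feat_lens = torch.full((5,), 20, dtype=torch.int64)

    # forward returns (encoder_out, encoder_out_lens); after PATCH 3 the output
    # dimension is d_model[1], produced by SimpleCombiner from the two stacks.
    out, out_lens = c(feats, feat_lens)
    assert out.shape[0] == 5 and out.shape[2] == 96       # (N, T', d_model[1])
    assert out.shape[1] == out_lens.max().item()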
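
On the train.py side, the former scalar options become comma-separated strings (--num-encoder-layers, --feedforward-dims, --nhead, --encoder-dims), plus the new --conformer-subsampling-factor. The sketch below shows how they are consumed, mirroring to_int_list() in get_encoder_model(); the values shown are the defaults used in this series (the subsampling default ends at 2 after PATCH 8).

    def to_int_list(s: str):
        # same helper as in get_encoder_model()
        return list(map(int, s.split(',')))

    encoder_dims = to_int_list("384,384")           # --encoder-dims, default after PATCH 5
    nheads = to_int_list("8,8")                     # --nhead
    feedforward_dims = to_int_list("1536,1536")     # --feedforward-dims
    num_encoder_layers = to_int_list("12,12")       # --num-encoder-layers
    conformer_subsampling_factor = 2                # --conformer-subsampling-factor, PATCH 8

    # The joiner and transducer take the last entry as the encoder output dim,
    # i.e. int(params.encoder_dims.split(',')[-1]) in get_joiner_model().
    encoder_dim = encoder_dims[-1]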