diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 2edfe376e..ffaa6660f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -41,7 +41,7 @@ class Conformer(EncoderInterface):
     Args:
         num_features (int): Number of input features
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): attention dimension, also the output dimension
+        d_model (Tuple[int]): (attention_dim1, attention_dim2, output_dim)
         nhead (int): number of heads
         dim_feedforward (int): feedforward dimension
         num_encoder_layers (int): number of encoder layers
@@ -55,13 +55,13 @@ class Conformer(EncoderInterface):
         self,
         num_features: int,
         subsampling_factor: int = 4,
-        d_model: int = 256,
-        nhead: int = 4,
-        dim_feedforward: int = 2048,
-        num_encoder_layers: int = 12,
+        conformer_subsampling_factor: int = 4,
+        d_model: Tuple[int] = (256, 384, 512),
+        nhead: Tuple[int] = (8, 8),
+        feedforward_dim: Tuple[int] = (1536, 2048),
+        num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
-        layer_dropout: float = 0.25,
-        cnn_module_kernel: int = 31,
+        cnn_module_kernel: Tuple[int] = (31, 31),
     ) -> None:
         super(Conformer, self).__init__()
@@ -75,22 +75,42 @@ class Conformer(EncoderInterface):
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_features -> d_model
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model[0],
+                                               dropout=dropout)
 
-        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-
-        encoder_layer = ConformerEncoderLayer(
-            d_model,
-            nhead,
-            dim_feedforward,
+        encoder_layer1 = ConformerEncoderLayer(
+            d_model[0],
+            nhead[0],
+            feedforward_dim[0],
             dropout,
-            cnn_module_kernel,
+            cnn_module_kernel[0],
         )
-        self.encoder = ConformerEncoder(
-            encoder_layer,
-            num_encoder_layers,
-            layer_dropout=layer_dropout,
+        self.encoder1 = ConformerEncoder(
+            encoder_layer1,
+            num_encoder_layers[0],
+            dropout,
         )
+        encoder_layer2 = ConformerEncoderLayer(
+            d_model[1],
+            nhead[1],
+            feedforward_dim[1],
+            dropout,
+            cnn_module_kernel[1],
+        )
+        self.encoder2 = DownsampledConformerEncoder(
+            ConformerEncoder(
+                encoder_layer2,
+                num_encoder_layers[1],
+                dropout,
+            ),
+            input_dim=d_model[0],
+            output_dim=d_model[1],
+            downsample=conformer_subsampling_factor,
+        )
+
+        self.out_combiner = SimpleCombiner(d_model[0],
+                                           d_model[1])
 
     def forward(
@@ -110,7 +130,7 @@ class Conformer(EncoderInterface):
           of frames in `embeddings` before padding.
         """
         x = self.encoder_embed(x)
-        x, pos_emb = self.encoder_pos(x)
+
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         with warnings.catch_warnings():
@@ -120,9 +140,17 @@ class Conformer(EncoderInterface):
             assert x.size(0) == lengths.max().item()
 
         mask = make_pad_mask(lengths)
-        x = self.encoder(
-            x, pos_emb, src_key_padding_mask=mask,
-        )  # (T, N, C)
+        # first encoder stack, at the full (subsampled) frame rate
+        x1 = self.encoder1(
+            x, src_key_padding_mask=mask,
+        )  # (T, N, C) where C == d_model[0]
+
+        x2 = self.encoder2(
+            x1, src_key_padding_mask=mask,
+        )  # (T, N, C) where C == d_model[1]
+
+        x = self.out_combiner(x1, x2)
 
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
@@ -137,7 +165,7 @@ class ConformerEncoderLayer(nn.Module):
     Args:
         d_model: the number of expected features in the input (required).
         nhead: the number of heads in the multiheadattention models (required).
-        dim_feedforward: the dimension of the feedforward network model (default=2048).
+        feedforward_dim: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
 
@@ -151,7 +179,7 @@ class ConformerEncoderLayer(nn.Module):
         self,
         d_model: int,
         nhead: int,
-        dim_feedforward: int = 2048,
+        feedforward_dim: int = 2048,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
     ) -> None:
@@ -164,22 +192,22 @@ class ConformerEncoderLayer(nn.Module):
         )
 
         self.feed_forward = nn.Sequential(
-            nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(dim_feedforward,
+            nn.Linear(d_model, feedforward_dim),
+            ActivationBalancer(feedforward_dim,
                                channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model,
+            ScaledLinear(feedforward_dim, d_model,
                          initial_scale=0.01),
         )
 
         self.feed_forward_macaron = nn.Sequential(
-            nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(dim_feedforward,
+            nn.Linear(d_model, feedforward_dim),
+            ActivationBalancer(feedforward_dim,
                                channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model,
+            ScaledLinear(feedforward_dim, d_model,
                          initial_scale=0.01),
         )
@@ -261,22 +289,23 @@ class ConformerEncoder(nn.Module):
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
         >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
-        >>> pos_emb = torch.rand(32, 19, 512)
-        >>> out = conformer_encoder(src, pos_emb)
+        >>> out = conformer_encoder(src)
+
+    Returns: the output of the final encoder layer, of shape (S, N, E).
     """
 
     def __init__(
         self,
        encoder_layer: nn.Module,
        num_layers: int,
-        layer_dropout: float = 0.25
+        dropout: float,
    ) -> None:
        super().__init__()
-        assert 0 < layer_dropout < 0.5
-        # `count` tracks how many times the forward function has been called
-        # since we initialized the model (it is not written to disk or read when
-        # we resume training).  It is used for random seeding for layer dropping.
-        self.count = 0
-        self.layer_dropout = layer_dropout
+
+        self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
+                                                 dropout)
 
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -287,19 +316,16 @@ class ConformerEncoder(nn.Module):
         num_channels = encoder_layer.norm_final.num_channels
 
-
     def forward(
         self,
         src: Tensor,
-        pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
         Args:
             src: the sequence to the encoder (required).
-            pos_emb: Positional embedding tensor (required).
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
@@ -310,7 +336,9 @@ class ConformerEncoder(nn.Module):
             src_key_padding_mask: (N, S).
             S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
 
+        Returns: x, of shape (S, N, E)
         """
+        pos_emb = self.encoder_pos(src)
         output = src
 
         outputs = []
@@ -337,7 +365,6 @@ class ConformerEncoder(nn.Module):
            frame_mask = torch.logical_or(frame_mask,
                                          torch.rand_like(src[:, :1, :1]) < 0.1)
-
            feature_mask[..., feature_unmasked_dim:] *= frame_mask
@@ -364,11 +391,190 @@ class ConformerEncoder(nn.Module):
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
+        output = output * feature_mask
 
         return output
 
 
+class DownsampledConformerEncoder(nn.Module):
+    r"""
+    DownsampledConformerEncoder is a conformer encoder evaluated at a reduced frame rate,
+    after attention-based downsampling, and then upsampled again at the output
+    so that the output has the same shape as the input.
+    """
+    def __init__(self,
+                 encoder: nn.Module,
+                 input_dim: int,
+                 output_dim: int,
+                 downsample: int):
+        super(DownsampledConformerEncoder, self).__init__()
+        self.downsample_factor = downsample
+        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
+        self.encoder = encoder
+        self.upsample = SimpleUpsample(output_dim, downsample)
+
+    def forward(self,
+                src: Tensor,
+                mask: Optional[Tensor] = None,
+                src_key_padding_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        r"""Downsample, go through encoder, upsample.
+
+        Args:
+            src: the sequence to the encoder (required).
+            mask: the mask for the src sequence (optional).  CAUTION: we need to downsample
+                  this, if we are to support it.  Won't work correctly yet.
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number.
+
+        Returns: output of shape (S, N, F) where F is the number of output features
+            (output_dim to constructor)
+        """
+        src_orig = src
+        src = self.downsample(src)
+        ds = self.downsample_factor
+        if mask is not None:
+            mask = mask[::ds, ::ds]
+        if src_key_padding_mask is not None:
+            # src_key_padding_mask is (N, S), so downsample along the time axis.
+            src_key_padding_mask = src_key_padding_mask[:, ::ds]
+
+        src = self.encoder(
+            src, src_key_padding_mask=src_key_padding_mask,
+        )
+        src = self.upsample(src)
+        # remove any extra frames that are not a multiple of downsample_factor
+        src = src[:src_orig.shape[0]]
+
+        return src
+
+
+class AttentionDownsample(torch.nn.Module):
+    """
+    Does downsampling with attention, by weighted sum, and a projection.
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 downsample: int):
+        """
+        Require out_channels >= in_channels.
+        """
+        super(AttentionDownsample, self).__init__()
+        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+
+        # fill in the extra dimensions with a projection of the input
+        if out_channels > in_channels:
+            self.extra_proj = nn.Linear(in_channels * downsample,
+                                        out_channels - in_channels,
+                                        bias=False)
+        else:
+            self.extra_proj = None
+        self.downsample = downsample
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, in_channels)
+        Returns a tensor of shape
+           ((seq_len + downsample - 1) // downsample, batch_size, out_channels)
+        """
+        (seq_len, batch_size, in_channels) = src.shape
+        ds = self.downsample
+        d_seq_len = (seq_len + ds - 1) // ds
+        # Pad to an exact multiple of self.downsample
+        if seq_len != d_seq_len * ds:
+            # right-pad src, repeating the last element.
+            pad = d_seq_len * ds - seq_len
+            src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
+            src = torch.cat((src, src_extra), dim=0)
+        assert src.shape[0] == d_seq_len * ds
+
+        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+        scores = (src * self.query).sum(dim=-1, keepdim=True)
+        weights = scores.softmax(dim=1)
+
+        # `ans` holds the first `in_channels` channels of the output
+        ans = (src * weights).sum(dim=1)
+        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
+
+        if self.extra_proj is not None:
+            ans2 = self.extra_proj(src)
+            ans = torch.cat((ans, ans2), dim=2)
+        return ans
+
+
+class SimpleUpsample(torch.nn.Module):
+    """
+    A very simple form of upsampling that mostly just repeats the input, but
+    also adds a position-specific bias.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 upsample: int):
+        super(SimpleUpsample, self).__init__()
+        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, num_channels)
+        Returns a tensor of shape
+           (seq_len * upsample, batch_size, num_channels)
+        """
+        upsample = self.bias.shape[0]
+        (seq_len, batch_size, num_channels) = src.shape
+        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+        src = src + self.bias.unsqueeze(1)
+        src = src.reshape(seq_len * upsample, batch_size, num_channels)
+        return src
+
+
+class SimpleCombiner(torch.nn.Module):
+    """
+    A very simple way of combining 2 vectors of 2 different dims, via a
+    learned weighted combination in the shared part of the dim.
+    """
+    def __init__(self,
+                 dim1: int,
+                 dim2: int):
+        super(SimpleCombiner, self).__init__()
+        assert dim2 >= dim1
+        self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
+        self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
+
+    def forward(self,
+                src1: Tensor,
+                src2: Tensor) -> Tensor:
+        """
+        src1: (*, dim1)
+        src2: (*, dim2)
+
+        Returns: a tensor of shape (*, dim2)
+        """
+        assert src1.shape[:-1] == src2.shape[:-1]
+        dim1 = src1.shape[-1]
+
+        weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
+        weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
+        weight = (weight1 + weight2).sigmoid()
+
+        src2_part1 = src2[..., :dim1]
+        part1 = src1 * weight + src2_part1 * (1.0 - weight)
+        part2 = src2[..., dim1:]
+        return torch.cat((part1, part2), dim=-1)
+
+
 class RelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding module.
 
@@ -385,7 +591,7 @@ class RelPositionalEncoding(torch.nn.Module):
     def __init__(
         self, d_model: int, dropout_rate: float, max_len: int = 5000
     ) -> None:
-        """Construct an PositionalEncoding object."""
+        """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
         self.dropout = torch.nn.Dropout(dropout_rate)
@@ -397,7 +603,7 @@ class RelPositionalEncoding(torch.nn.Module):
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
+            if self.pe.size(1) >= x.size(0) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -407,9 +613,9 @@ class RelPositionalEncoding(torch.nn.Module):
         # Suppose `i` means to the position of query vector and `j` means
         # the position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x.size(1), self.d_model)
-        pe_negative = torch.zeros(x.size(1), self.d_model)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        pe_positive = torch.zeros(x.size(0), self.d_model)
+        pe_negative = torch.zeros(x.size(0), self.d_model)
+        position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
@@ ... @@ class Conv2dSubsampling(nn.Module):
         layer1_channels: int = 8,
         layer2_channels: int = 32,
         layer3_channels: int = 128,
+        dropout: float = 0.1,
     ) -> None:
         """
         Args:
@@ -1012,6 +1219,7 @@ class Conv2dSubsampling(nn.Module):
         )
         out_height = (((in_channels - 1) // 2 - 1) // 2)
         self.out = ScaledLinear(out_height * layer3_channels, out_channels)
+        self.dropout = nn.Dropout(dropout)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1031,6 +1239,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+        x = self.dropout(x)
         return x
 
 
 class AttentionCombine(nn.Module):
@@ -1166,14 +1375,13 @@ def _test_random_combine():
 
 def _test_conformer_main():
     feature_dim = 50
-    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     feature_dim = 50
     # Just make sure the forward pass runs.
     c = Conformer(
-        num_features=feature_dim, d_model=128, nhead=4
+        num_features=feature_dim, d_model=(64, 96, 128), nhead=(4, 4)
     )
     batch_size = 5
     seq_len = 20
@@ -1191,8 +1399,6 @@ def _test_conformer_main():
 
     f  # to remove flake8 warnings
 
-
-
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
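To make the shape contract of the new `DownsampledConformerEncoder` concrete, here is a minimal, self-contained sketch (not part of the patch) of the pad/downsample/upsample round trip. The helpers `toy_downsample` and `toy_upsample` are hypothetical names for illustration; they mirror only the padding and reshaping logic, replacing the learned attention weighting, extra-channel projection, and position-specific bias with a plain mean and repeat.

```python
# Sketch of the AttentionDownsample/SimpleUpsample shape contract.
import torch

def toy_downsample(src: torch.Tensor, ds: int) -> torch.Tensor:
    # src: (seq_len, batch, channels) -> ((seq_len + ds - 1) // ds, batch, channels)
    (seq_len, batch, channels) = src.shape
    d_seq_len = (seq_len + ds - 1) // ds
    if seq_len != d_seq_len * ds:
        # right-pad by repeating the last frame, as in AttentionDownsample
        pad = d_seq_len * ds - seq_len
        src = torch.cat((src, src[-1:].expand(pad, batch, channels)), dim=0)
    src = src.reshape(d_seq_len, ds, batch, channels)
    return src.mean(dim=1)  # the patch uses a learned attention weighting here

def toy_upsample(src: torch.Tensor, ds: int) -> torch.Tensor:
    # (d_seq_len, batch, channels) -> (d_seq_len * ds, batch, channels)
    (d_seq_len, batch, channels) = src.shape
    src = src.unsqueeze(1).expand(d_seq_len, ds, batch, channels)
    return src.reshape(d_seq_len * ds, batch, channels)

seq_len, batch, channels, ds = 23, 5, 64, 2
x = torch.randn(seq_len, batch, channels)
y = toy_upsample(toy_downsample(x, ds), ds)
# trim the frames introduced by padding, exactly as in
# DownsampledConformerEncoder.forward
y = y[:seq_len]
assert y.shape == x.shape
```

This is why `encoder2` can be combined frame-by-frame with `encoder1`'s output in `SimpleCombiner`: the round trip preserves the sequence length even when it is not a multiple of the downsampling factor.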
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index d3680e75e..fda28302f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -91,30 +91,38 @@ LRSchedulerType = Union[
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
-        type=int,
-        default=24,
-        help="Number of conformer encoder layers..",
+        type=str,
+        default="12,12",
+        help="Number of conformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
-        "--dim-feedforward",
-        type=int,
-        default=1536,
-        help="Feedforward dimension of the conformer encoder layer.",
+        "--feedforward-dims",
+        type=str,
+        default="1536,1536",
+        help="Feedforward dimension of the conformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
         "--nhead",
-        type=int,
-        default=8,
-        help="Number of attention heads in the conformer encoder layer.",
+        type=str,
+        default="8,8",
+        help="Number of attention heads in the conformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
-        "--encoder-dim",
+        "--encoder-dims",
+        type=str,
+        default="384,384",
+        help="Attention dimensions of the 2 blocks of conformer encoder layers, comma separated; "
+        "the last value is also the output dim of the encoder.",
+    )
+
+    parser.add_argument(
+        "--conformer-subsampling-factor",
         type=int,
-        default=384,
-        help="Attention dimension in the conformer encoder layer.",
+        default=2,
+        help="Subsampling factor for the 2nd stack of encoder layers.",
     )
 
     parser.add_argument(
@@ -401,13 +409,16 @@ def get_params() -> AttributeDict:
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
+    def to_int_list(s: str):
+        return list(map(int, s.split(',')))
     encoder = Conformer(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
-        d_model=params.encoder_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
+        conformer_subsampling_factor=params.conformer_subsampling_factor,
+        d_model=to_int_list(params.encoder_dims),
+        nhead=to_int_list(params.nhead),
+        feedforward_dim=to_int_list(params.feedforward_dims),
+        num_encoder_layers=to_int_list(params.num_encoder_layers),
     )
     return encoder
 
@@ -424,7 +435,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=params.encoder_dim,
+        encoder_dim=int(params.encoder_dims.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -441,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=params.encoder_dim,
+        encoder_dim=int(params.encoder_dims.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
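For reference, a small sketch (not part of the patch) of how the new comma-separated options are consumed. The option values below are the defaults added in `add_model_arguments`, and `to_int_list` mirrors the helper defined inside `get_encoder_model`:

```python
# Sketch of the comma-separated hyperparameter handling in train.py.
from typing import List

def to_int_list(s: str) -> List[int]:
    return list(map(int, s.split(',')))

encoder_dims = "384,384"        # --encoder-dims
num_encoder_layers = "12,12"    # --num-encoder-layers
feedforward_dims = "1536,1536"  # --feedforward-dims
nhead = "8,8"                   # --nhead

d_model = to_int_list(encoder_dims)  # [384, 384], one dim per encoder stack
assert len(to_int_list(num_encoder_layers)) == len(d_model)

# The joiner and transducer see only the dim of the last encoder stack,
# which is why get_joiner_model() takes encoder_dims.split(',')[-1].
joiner_encoder_dim = int(encoder_dims.split(',')[-1])  # 384
print(d_model, joiner_encoder_dim)
```

Keeping the options as strings and splitting at model-construction time means one flag can configure both encoder stacks without changing the argparse interface each time a stack is added.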