From 03fe1ed20070de4914842ad5964ce163ae4ae724 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 17 Oct 2022 11:03:29 +0800
Subject: [PATCH] Make attention dims configurable, not embed_dim//2, trying 256.

---
 .../pruned_transducer_stateless7/conformer.py | 65 +++++++++++--------
 .../ASR/pruned_transducer_stateless7/train.py | 23 ++++---
 2 files changed, 52 insertions(+), 36 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 2cbb9c570..61e1edc5a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -44,9 +44,10 @@ class Conformer(EncoderInterface):
     Args:
         num_features (int): Number of input features
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): embedding dimension
-        nhead (int): number of head
-        dim_feedforward (int): feedforward dimention
+        d_model (int, int): embedding dimension of the 2 encoder stacks
+        attention_dim (int, int): attention dimension of the 2 encoder stacks
+        nhead (int, int): number of heads
+        feedforward_dim (int, int): feedforward dimension in the 2 encoder stacks
         num_encoder_layers (int): number of encoder layers
        dropout (float): dropout rate
        cnn_module_kernel (int): Kernel size of convolution module
@@ -60,6 +61,7 @@ class Conformer(EncoderInterface):
         subsampling_factor: int = 4,
         conformer_subsampling_factor: int = 4,
         d_model: Tuple[int] = (384, 384),
+        attention_dim: Tuple[int] = (256, 256),
         encoder_unmasked_dim: int = 256,
         nhead: Tuple[int] = (8, 8),
         feedforward_dim: Tuple[int] = (1536, 2048),
@@ -92,6 +94,7 @@ class Conformer(EncoderInterface):
 
         encoder_layer1 = ConformerEncoderLayer(
             d_model[0],
+            attention_dim[0],
             nhead[0],
             feedforward_dim[0],
             dropout,
@@ -110,6 +113,7 @@ class Conformer(EncoderInterface):
         )
         encoder_layer2 = ConformerEncoderLayer(
             d_model[1],
+            attention_dim[1],
             nhead[1],
             feedforward_dim[1],
             dropout,
@@ -248,19 +252,20 @@ class ConformerEncoderLayer(nn.Module):
         >>> out = encoder_layer(src, pos_emb)
     """
     def __init__(
-        self,
-        d_model: int,
-        nhead: int,
-        feedforward_dim: int = 2048,
-        dropout: float = 0.1,
-        cnn_module_kernel: int = 31,
+            self,
+            d_model: int,
+            attention_dim: int,
+            nhead: int,
+            feedforward_dim: int = 2048,
+            dropout: float = 0.1,
+            cnn_module_kernel: int = 31,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
 
         self.d_model = d_model
 
         self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0,
+            d_model, attention_dim, nhead, dropout=0.0,
         )
 
         self.feed_forward1 = FeedforwardModule(d_model,
@@ -807,6 +812,8 @@ class RelPositionMultiheadAttention(nn.Module):
 
     Args:
         embed_dim: total dimension of the model.
+        attention_dim: dimension used inside the attention module; may be smaller or larger
+            than embed_dim, but must be a multiple of num_heads.
         num_heads: parallel attention heads.
         dropout: a Dropout layer on attn_output_weights. Default: 0.0.
@@ -819,19 +826,21 @@ class RelPositionMultiheadAttention(nn.Module):
     def __init__(
         self,
         embed_dim: int,
+        attention_dim: int,
         num_heads: int,
         dropout: float = 0.0,
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
+        self.attention_dim = attention_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.head_dim = embed_dim // (num_heads * 2)
+        self.head_dim = attention_dim // num_heads
         assert (
-            self.head_dim * num_heads == self.embed_dim // 2
-        ), "embed_dim//2 must be divisible by num_heads"
+            self.head_dim * num_heads == attention_dim
+        ), "attention_dim must be divisible by num_heads"
 
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+        self.in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
 
         # self.whiten_values is applied on the values in forward()
         self.whiten_values = Whiten(num_groups=num_heads,
@@ -845,14 +854,14 @@ class RelPositionMultiheadAttention(nn.Module):
                                   grad_scale=0.025)
 
-        self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
+        self.in_balancer = ActivationBalancer(3 * attention_dim,
                                               channel_dim=-1, max_abs=5.0)
 
         self.out_proj = ScaledLinear(
-            embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
+            attention_dim, embed_dim, bias=True, initial_scale=0.05
         )
 
-        self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
-        self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
+        self.in_proj2 = nn.Linear(embed_dim, attention_dim, bias=False)
+        self.out_proj2 = ScaledLinear(attention_dim, embed_dim, bias=True,
                                       initial_scale=0.05)
         # self.whiten_values2 is applied on the values in forward2()
         self.whiten_values2 = Whiten(num_groups=num_heads,
@@ -914,7 +923,7 @@ class RelPositionMultiheadAttention(nn.Module):
         x, weights = self.multi_head_attention_forward(
             self.in_balancer(self.in_proj(x)),
             self.linear_pos(pos_emb),
-            self.embed_dim,
+            self.attention_dim,
             self.num_heads,
             self.in_proj.weight,
             self.in_proj.bias,
@@ -965,7 +974,7 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos: Tensor,
-        embed_dim: int,
+        attention_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
         in_proj_bias: Tensor,
@@ -980,7 +989,7 @@ class RelPositionMultiheadAttention(nn.Module):
         Args:
             x_proj: the projected input, to be split into query, key, value.
             pos: head-specific biases arising from the positional embeddings.
-            embed_dim: total dimension of the model.
+            attention_dim: dimension inside the attention mechanism.
             num_heads: parallel attention heads.
             in_proj_weight, in_proj_bias: input projection weight and bias.
             dropout_p: probability of an element to be zeroed.
@@ -994,8 +1003,8 @@ class RelPositionMultiheadAttention(nn.Module):
 
         Shape:
             Inputs:
-            - x: :math:`(L, N, 3 * E//2)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension. Will be split into (query, key, value).
+            - x: :math:`(L, N, 3 * A)` where L is the target sequence length, N is the batch size, A is
+              the attention dimension. Will be split into (query, key, value).
             - pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence length, N is the batch size, and H is the number of heads.
             - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
@@ -1019,10 +1028,10 @@ class RelPositionMultiheadAttention(nn.Module):
 
         seq_len, bsz, _ = x.size()
 
-        head_dim = embed_dim // (num_heads * 2)
+        head_dim = attention_dim // num_heads
         assert (
-            head_dim * num_heads == embed_dim // 2
-        ), "embed_dim must be divisible by num_heads"
+            head_dim * num_heads == attention_dim
+        ), "attention_dim must be divisible by num_heads"
 
         scaling = float(head_dim) ** -0.5
@@ -1142,7 +1151,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, embed_dim // 2)
+            .view(seq_len, bsz, attention_dim)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
@@ -1167,7 +1176,7 @@ class RelPositionMultiheadAttention(nn.Module):
         """
         num_heads = self.num_heads
         (seq_len, bsz, embed_dim) = x.shape
-        head_dim = embed_dim // (num_heads * 2)
-        # v: (tgt_len, bsz, embed_dim // 2)
+        head_dim = self.attention_dim // num_heads
+        # v: (tgt_len, bsz, attention_dim)
         v = self.in_proj2(x)
         v = self.whiten_values2(v)  # does nothing in the forward pass.
@@ -1183,7 +1192,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, embed_dim // 2)
+            .view(seq_len, bsz, self.attention_dim)
         )
         # returned value is of shape (seq_len, bsz, embed_dim), like x.
         return self.out_proj2(attn_output)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 68b186dcd..8bdf9e40d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -114,8 +114,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--encoder-dims",
         type=str,
         default="384,384",
-        help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, "
-        "and the output dim of the encoder",
+        help="Embedding dimension in the 2 blocks of conformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="256,256",
+        help="Attention dimension in the 2 blocks of conformer encoder layers, comma separated",
     )
 
     parser.add_argument(
@@ -418,17 +424,18 @@ def get_params() -> AttributeDict:
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
-    def to_int_list(s: str):
-        return list(map(int, s.split(',')))
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(',')))
 
     encoder = Conformer(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
         conformer_subsampling_factor=params.conformer_subsampling_factor,
-        d_model=to_int_list(params.encoder_dims),
+        d_model=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
         encoder_unmasked_dim=params.encoder_unmasked_dim,
-        nhead=to_int_list(params.nhead),
-        feedforward_dim=to_int_list(params.feedforward_dims),
-        num_encoder_layers=to_int_list(params.num_encoder_layers),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
     )
     return encoder
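
Note on the conformer.py change: the attention dimension is now decoupled from the embedding dimension. The input projection maps embed_dim -> 3 * attention_dim (query, key, value), the output projection maps attention_dim -> embed_dim, and head_dim = attention_dim // num_heads, whereas previously head_dim was embed_dim // (num_heads * 2), i.e. the attention always ran at embed_dim // 2. Below is a minimal standalone sketch of just this shape bookkeeping; it uses plain nn.Linear in place of icefall's ScaledLinear / ActivationBalancer / Whiten wrappers, omits the relative positional scores, and the class name is made up for illustration.

    import torch
    import torch.nn as nn


    class AttnShapeSketch(nn.Module):
        # Illustrates the embed_dim vs. attention_dim plumbing; not the icefall module.
        def __init__(self, embed_dim: int, attention_dim: int, num_heads: int):
            super().__init__()
            assert attention_dim % num_heads == 0, "attention_dim must be divisible by num_heads"
            self.num_heads = num_heads
            self.attention_dim = attention_dim
            self.head_dim = attention_dim // num_heads
            # embed_dim -> 3 * attention_dim, like the patched self.in_proj
            self.in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
            # attention_dim -> embed_dim, like the patched self.out_proj
            self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (seq_len, batch, embed_dim)
            seq_len, bsz, _ = x.shape
            q, k, v = self.in_proj(x).chunk(3, dim=-1)  # each (seq_len, batch, attention_dim)
            q = q.reshape(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            k = k.reshape(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            v = v.reshape(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            attn = (q @ k.transpose(1, 2)) * (self.head_dim ** -0.5)  # (batch*heads, seq, seq)
            attn = attn.softmax(dim=-1)
            out = (attn @ v).transpose(0, 1).reshape(seq_len, bsz, self.attention_dim)
            return self.out_proj(out)  # back to (seq_len, batch, embed_dim)


    x = torch.randn(50, 4, 384)  # embed_dim = 384, as in --encoder-dims "384,384"
    m = AttnShapeSketch(embed_dim=384, attention_dim=256, num_heads=8)
    print(m(x).shape)  # torch.Size([50, 4, 384])

With the old code, a 384-dim stack effectively ran its attention at 384 // 2 = 192 dims; with the new defaults it runs at 256, which is what "trying 256" in the subject refers to.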
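
On the train.py side, the two comma-separated flags are parsed into tuples and passed to the Conformer constructor. A small sketch of just that path, with all other flags and the Conformer import elided:

    import argparse


    def to_int_tuple(s: str):
        return tuple(map(int, s.split(',')))


    parser = argparse.ArgumentParser()
    parser.add_argument("--encoder-dims", type=str, default="384,384",
                        help="Embedding dimension in the 2 blocks of conformer encoder layers, comma separated")
    parser.add_argument("--attention-dims", type=str, default="256,256",
                        help="Attention dimension in the 2 blocks of conformer encoder layers, comma separated")

    args = parser.parse_args([])  # pass e.g. ["--attention-dims", "192,192"] to override the defaults
    d_model = to_int_tuple(args.encoder_dims)          # (384, 384)
    attention_dim = to_int_tuple(args.attention_dims)  # (256, 256)
    # These would be forwarded as Conformer(d_model=d_model, attention_dim=attention_dim, ...)
    print(d_model, attention_dim)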