From 12323f2fbf77db12aa86d7a6f314852fa70a7cda Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 10 Oct 2022 15:27:18 +0800 Subject: [PATCH] Refactor RelPosMultiheadAttention to have 2nd forward function and introduce more modules in conformer encoder layer --- .../pruned_transducer_stateless7/conformer.py | 150 ++++++++++++------ .../ASR/pruned_transducer_stateless7/train.py | 2 +- 2 files changed, 104 insertions(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 9a78c6838..c38293788 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -265,28 +265,24 @@ class ConformerEncoderLayer(nn.Module): d_model, nhead, dropout=dropout, ) - self.feed_forward = nn.Sequential( - nn.Linear(d_model, feedforward_dim), - ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(feedforward_dim, d_model, - initial_scale=0.01), - ) + self.feed_forward1 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, feedforward_dim), - ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(feedforward_dim, d_model, - initial_scale=0.01), - ) + self.feed_forward2 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.conv_module = ConvolutionModule(d_model, - cnn_module_kernel) + self.feed_forward3 = FeedforwardModule(d_model, + feedforward_dim, + dropout) + + + self.conv_module1 = ConvolutionModule(d_model, + cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(d_model, + cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -330,10 +326,10 @@ class ConformerEncoderLayer(nn.Module): src_orig = src # macaron style feed forward module - src = src + self.feed_forward_macaron(src) + src = src + self.feed_forward1(src) # multi-headed self-attention module - src_att, _, attn_scores_out = self.self_attn( + src_att, attn_weights, attn_scores_out = self.self_attn( src, pos_emb=pos_emb, attn_scores_in=attn_scores_in, @@ -343,11 +339,16 @@ class ConformerEncoderLayer(nn.Module): src = src + src_att # convolution module - src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask) - # feed forward module - src = src + self.feed_forward(src) + src = src + self.feed_forward2(src) + + src = src + self.self_attn.forward2(src, attn_weights) + + src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask) + + src = src + self.feed_forward3(src) src = self.norm_final(self.balancer(src)) @@ -846,6 +847,10 @@ class RelPositionMultiheadAttention(nn.Module): embed_dim // 2, embed_dim, bias=True, initial_scale=0.05 ) + self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False) + self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True, + initial_scale=0.05) + self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads)) self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads)) @@ -867,9 +872,8 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb: Tensor, attn_scores_in: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + ) -> Tuple[Tensor, Tensor, Tensor]: r""" Args: x: input to be projected to query, key, value @@ -879,7 +883,6 @@ class RelPositionMultiheadAttention(nn.Module): the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored - need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. @@ -902,11 +905,14 @@ class RelPositionMultiheadAttention(nn.Module): is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. + - Returns: (attn_output, attn_weights, attn_scores) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + - attn_scores: :math:`(N, S, S, H)`, these are the attn weights + before softmax. """ x, weights, scores = self.multi_head_attention_forward( self.in_balancer(self.in_proj(x)), @@ -921,7 +927,6 @@ class RelPositionMultiheadAttention(nn.Module): self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, - need_weights=need_weights, attn_mask=attn_mask, ) if attn_scores_in is not None: @@ -973,7 +978,6 @@ class RelPositionMultiheadAttention(nn.Module): out_proj_bias: Tensor, training: bool = True, key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, attn_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: r""" @@ -989,7 +993,6 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: if provided, specified padding elements in the key will be ignored by the attention. This is an binary mask. When the value is True, the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. @@ -1017,9 +1020,11 @@ class RelPositionMultiheadAttention(nn.Module): Outputs: - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + - attn_scores: :math:`(N, S, S, H)` where N is the batch size, + S is the sequence length and H is the num-heads. """ tgt_len, bsz, _ = x.size() @@ -1182,14 +1187,65 @@ class RelPositionMultiheadAttention(nn.Module): attn_output, out_proj_weight, out_proj_bias ) - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads, attn_scores_out - else: - return attn_output, None, attn_scores_out + return attn_output, attn_output_weights, attn_scores_out + + + def forward2( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + Returns: + output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = embed_dim // (num_heads * 2) + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1) + # now v: (bsz * num_heads, seq_len, head_dim) + attn_output = torch.bmm(attn_weights, v) + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, embed_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output) + + +class FeedforwardModule(nn.Module): + """Feedforward module in Conformer model. + """ + def __init__(self, + d_model: int, + feedforward_dim: int, + dropout: float): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(d_model, feedforward_dim) + self.balancer = ActivationBalancer(feedforward_dim, + channel_dim=-1, max_abs=10.0) + self.activation = DoubleSwish() + self.dropout = nn.Dropout(dropout) + self.out_proj = ScaledLinear(feedforward_dim, d_model, + initial_scale=0.01) + + def forward(self, + x: Tensor): + x = self.in_proj(x) + x = self.balancer(x) + x = self.activation(x) + x = self.dropout(x) + x = self.out_proj(x) + return x class ConvolutionModule(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index c084b5556..44dbc95df 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -92,7 +92,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="12,12", + default="7,7", help="Number of conformer encoder layers, comma separated.", )