Refactor RelPosMultiheadAttention to have 2nd forward function and introduce more modules in conformer encoder layer

Daniel Povey 2022-10-10 15:27:18 +08:00
parent f941991331
commit 12323f2fbf
2 changed files with 104 additions and 48 deletions


@@ -265,28 +265,24 @@ class ConformerEncoderLayer(nn.Module):
             d_model, nhead, dropout=dropout,
         )
 
-        self.feed_forward = nn.Sequential(
-            nn.Linear(d_model, feedforward_dim),
-            ActivationBalancer(feedforward_dim,
-                               channel_dim=-1, max_abs=10.0),
-            DoubleSwish(),
-            nn.Dropout(dropout),
-            ScaledLinear(feedforward_dim, d_model,
-                         initial_scale=0.01),
-        )
-
-        self.feed_forward_macaron = nn.Sequential(
-            nn.Linear(d_model, feedforward_dim),
-            ActivationBalancer(feedforward_dim,
-                               channel_dim=-1, max_abs=10.0),
-            DoubleSwish(),
-            nn.Dropout(dropout),
-            ScaledLinear(feedforward_dim, d_model,
-                         initial_scale=0.01),
-        )
-
-        self.conv_module = ConvolutionModule(d_model,
-                                             cnn_module_kernel)
+        self.feed_forward1 = FeedforwardModule(d_model,
+                                               feedforward_dim,
+                                               dropout)
+
+        self.feed_forward2 = FeedforwardModule(d_model,
+                                               feedforward_dim,
+                                               dropout)
+
+        self.feed_forward3 = FeedforwardModule(d_model,
+                                               feedforward_dim,
+                                               dropout)
+
+        self.conv_module1 = ConvolutionModule(d_model,
+                                              cnn_module_kernel)
+
+        self.conv_module2 = ConvolutionModule(d_model,
+                                              cnn_module_kernel)
 
         self.norm_final = BasicNorm(d_model)
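For orientation (not part of the commit): each FeedforwardModule above replaces one of the removed Linear / ActivationBalancer / DoubleSwish / Dropout / ScaledLinear stacks; the class itself is added near the end of this file's diff. A rough, shape-only stand-in using stock torch layers (SiLU and plain Linear are placeholders for DoubleSwish and ScaledLinear, which are defined elsewhere in this recipe):

import torch
import torch.nn as nn

def feedforward_stub(d_model: int, feedforward_dim: int, dropout: float) -> nn.Sequential:
    # Shape-equivalent sketch only; the real module uses ActivationBalancer,
    # DoubleSwish and ScaledLinear(initial_scale=0.01) instead of these layers.
    return nn.Sequential(
        nn.Linear(d_model, feedforward_dim),   # in_proj
        nn.SiLU(),                             # placeholder for DoubleSwish
        nn.Dropout(dropout),
        nn.Linear(feedforward_dim, d_model),   # placeholder for ScaledLinear
    )

print(feedforward_stub(256, 1024, 0.1)(torch.randn(50, 4, 256)).shape)  # torch.Size([50, 4, 256])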
@@ -330,10 +326,10 @@ class ConformerEncoderLayer(nn.Module):
         src_orig = src
 
         # macaron style feed forward module
-        src = src + self.feed_forward_macaron(src)
+        src = src + self.feed_forward1(src)
 
         # multi-headed self-attention module
-        src_att, _, attn_scores_out = self.self_attn(
+        src_att, attn_weights, attn_scores_out = self.self_attn(
             src,
             pos_emb=pos_emb,
             attn_scores_in=attn_scores_in,
@@ -343,11 +339,16 @@ class ConformerEncoderLayer(nn.Module):
         src = src + src_att
 
         # convolution module
-        src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
 
-        # feed forward module
-        src = src + self.feed_forward(src)
+        src = src + self.feed_forward2(src)
+
+        src = src + self.self_attn.forward2(src, attn_weights)
+
+        src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
+
+        src = src + self.feed_forward3(src)
 
         src = self.norm_final(self.balancer(src))
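The new per-layer order is ff1, self-attention, conv1, ff2, a second attention pass that re-uses the first pass's weights (self_attn.forward2), conv2, ff3, then the final norm, each step wrapped in a residual connection. A self-contained toy illustration of that ordering (every stub below is a placeholder, not the real module):

import torch
import torch.nn as nn

class ToyLayerOrder(nn.Module):
    # Shows only the residual ordering introduced by this commit.
    def __init__(self, d_model: int):
        super().__init__()
        self.feed_forward1 = nn.Identity()
        self.self_attn = nn.Identity()      # real module also returns attn_weights, attn_scores
        self.conv_module1 = nn.Identity()
        self.feed_forward2 = nn.Identity()
        self.second_attn = nn.Identity()    # real call is self.self_attn.forward2(src, attn_weights)
        self.conv_module2 = nn.Identity()
        self.feed_forward3 = nn.Identity()
        self.norm_final = nn.LayerNorm(d_model)  # placeholder for balancer + BasicNorm

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        src = src + self.feed_forward1(src)
        src = src + self.self_attn(src)
        src = src + self.conv_module1(src)
        src = src + self.feed_forward2(src)
        src = src + self.second_attn(src)
        src = src + self.conv_module2(src)
        src = src + self.feed_forward3(src)
        return self.norm_final(src)

print(ToyLayerOrder(256)(torch.randn(50, 4, 256)).shape)  # torch.Size([50, 4, 256])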
@@ -846,6 +847,10 @@ class RelPositionMultiheadAttention(nn.Module):
             embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
+        self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
+        self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
+                                      initial_scale=0.05)
+
         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
         self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
@@ -867,9 +872,8 @@ class RelPositionMultiheadAttention(nn.Module):
         pos_emb: Tensor,
         attn_scores_in: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    ) -> Tuple[Tensor, Tensor, Tensor]:
         r"""
         Args:
             x: input to be projected to query, key, value
@@ -879,7 +883,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 the corresponding value on the attention layer will be ignored. When given
                 a byte mask and a value is non-zero, the corresponding value on the attention
                 layer will be ignored
-            need_weights: output attn_output_weights.
             attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                 the batches while a 3D mask allows to specify a different mask for the entries of each batch.
@@ -902,11 +905,14 @@ class RelPositionMultiheadAttention(nn.Module):
             is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
             is provided, it will be added to the attention weight.
 
-        - Outputs:
-        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
-          E is the embedding dimension.
-        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
-          L is the target sequence length, S is the source sequence length.
+        - Returns: (attn_output, attn_weights, attn_scores)
+        - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size,
+          E is the embedding dimension.
+        - attn_weights: :math:`(N * H, S, S)` where N is the batch size, H is the num-heads
+          and S is the sequence length.
+        - attn_scores: :math:`(N, S, S, H)`, these are the attn weights
+          before softmax.
         """
         x, weights, scores = self.multi_head_attention_forward(
             self.in_balancer(self.in_proj(x)),
@@ -921,7 +927,6 @@ class RelPositionMultiheadAttention(nn.Module):
             self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
-            need_weights=need_weights,
             attn_mask=attn_mask,
         )
         if attn_scores_in is not None:
@@ -973,7 +978,6 @@ class RelPositionMultiheadAttention(nn.Module):
         out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
@@ -989,7 +993,6 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask: if provided, specified padding elements in the key will
                 be ignored by the attention. This is an binary mask. When the value is True,
                 the corresponding value on the attention layer will be filled with -inf.
-            need_weights: output attn_output_weights.
             attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                 the batches while a 3D mask allows to specify a different mask for the entries of each batch.
@@ -1017,9 +1020,11 @@ class RelPositionMultiheadAttention(nn.Module):
         Outputs:
         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
           E is the embedding dimension.
-        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
-          L is the target sequence length, S is the source sequence length.
+        - attn_weights: :math:`(N * H, S, S)` where N is the batch size,
+          H is the num-heads, S is the sequence length.
+        - attn_scores: :math:`(N, S, S, H)` where N is the batch size,
+          S is the sequence length and H is the num-heads.
         """
 
         tgt_len, bsz, _ = x.size()
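To make the two layouts in the docstring concrete: attn_weights is kept flat as (N * H, S, S) so it can be fed straight to torch.bmm, while attn_scores uses (N, S, S, H), presumably so the trailing head dimension can be mixed by the attn_scores_proj_in/out parameters shown earlier in this diff. A stand-alone reshape sketch with invented sizes (the permute order is illustrative, not taken from this file):

import torch

N, H, S = 4, 8, 50                          # batch, num-heads, sequence length (illustrative)
attn_weights = torch.softmax(torch.randn(N * H, S, S), dim=-1)

per_head = attn_weights.view(N, H, S, S)    # unflatten the batch*heads dimension
scores_like = per_head.permute(0, 2, 3, 1)  # (N, S, S, H), the attn_scores layout
print(per_head.shape, scores_like.shape)    # torch.Size([4, 8, 50, 50]) torch.Size([4, 50, 50, 8])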
@@ -1182,14 +1187,65 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output, out_proj_weight, out_proj_bias
         )
 
-        if need_weights:
-            # average attention weights over heads
-            attn_output_weights = attn_output_weights.view(
-                bsz, num_heads, tgt_len, src_len
-            )
-            return attn_output, attn_output_weights.sum(dim=1) / num_heads, attn_scores_out
-        else:
-            return attn_output, None, attn_scores_out
+        return attn_output, attn_output_weights, attn_scores_out
+
+    def forward2(
+        self,
+        x: Tensor,
+        attn_weights: Tensor,
+    ) -> Tensor:
+        """
+        Second forward function, where we re-use the attn_weights returned by the first forward function
+        but with different input.
+        Args:
+            x: input, of shape (seq_len, batch_size, embed_dim)
+            attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len)
+        Returns:
+            output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
+        """
+        num_heads = self.num_heads
+        (seq_len, bsz, embed_dim) = x.shape
+        head_dim = embed_dim // (num_heads * 2)
+        # v: (seq_len, bsz, embed_dim // 2)
+        v = self.in_proj2(x)
+        v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        # now v: (bsz * num_heads, seq_len, head_dim)
+        attn_output = torch.bmm(attn_weights, v)
+        # attn_output: (bsz * num_heads, seq_len, head_dim)
+        attn_output = (
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(seq_len, bsz, embed_dim // 2)
+        )
+        # returned value is of shape (seq_len, bsz, embed_dim), like x.
+        return self.out_proj2(attn_output)
+
+
+class FeedforwardModule(nn.Module):
+    """Feedforward module in Conformer model.
+    """
+    def __init__(self,
+                 d_model: int,
+                 feedforward_dim: int,
+                 dropout: float):
+        super(FeedforwardModule, self).__init__()
+        self.in_proj = nn.Linear(d_model, feedforward_dim)
+        self.balancer = ActivationBalancer(feedforward_dim,
+                                           channel_dim=-1, max_abs=10.0)
+        self.activation = DoubleSwish()
+        self.dropout = nn.Dropout(dropout)
+        self.out_proj = ScaledLinear(feedforward_dim, d_model,
+                                     initial_scale=0.01)
+
+    def forward(self,
+                x: Tensor):
+        x = self.in_proj(x)
+        x = self.balancer(x)
+        x = self.activation(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
 class ConvolutionModule(nn.Module):
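A stand-alone shape check for the forward2() path above, mirroring its reshapes with invented sizes (a plain nn.Linear stands in for ScaledLinear):

import torch
import torch.nn as nn

seq_len, bsz, embed_dim, num_heads = 50, 4, 256, 8
head_dim = embed_dim // (num_heads * 2)                      # values live in a half-width space

x = torch.randn(seq_len, bsz, embed_dim)
attn_weights = torch.softmax(torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1)

in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
out_proj2 = nn.Linear(embed_dim // 2, embed_dim, bias=True)  # ScaledLinear in the real code

v = in_proj2(x)                                              # (seq_len, bsz, embed_dim // 2)
v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
out = torch.bmm(attn_weights, v)                             # (bsz * num_heads, seq_len, head_dim)
out = out.transpose(0, 1).contiguous().view(seq_len, bsz, embed_dim // 2)
print(out_proj2(out).shape)                                  # torch.Size([50, 4, 256])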


@@ -92,7 +92,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="12,12",
+        default="7,7",
         help="Number of conformer encoder layers, comma separated.",
     )
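The new default gives 7 encoder layers per stack instead of 12. How the comma-separated value is parsed is not shown in this diff; a typical pattern would be:

# Hypothetical parsing of the comma-separated option (not taken from this diff).
num_encoder_layers = [int(n) for n in "7,7".split(",")]
print(num_encoder_layers)  # [7, 7]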