Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Refactor RelPosMultiheadAttention to have a 2nd forward function, and introduce more modules in the conformer encoder layer
parent f941991331
commit 12323f2fbf
@@ -265,28 +265,24 @@ class ConformerEncoderLayer(nn.Module):
             d_model, nhead, dropout=dropout,
         )

-        self.feed_forward = nn.Sequential(
-            nn.Linear(d_model, feedforward_dim),
-            ActivationBalancer(feedforward_dim,
-                               channel_dim=-1, max_abs=10.0),
-            DoubleSwish(),
-            nn.Dropout(dropout),
-            ScaledLinear(feedforward_dim, d_model,
-                         initial_scale=0.01),
-        )
+        self.feed_forward1 = FeedforwardModule(d_model,
+                                               feedforward_dim,
+                                               dropout)

-        self.feed_forward_macaron = nn.Sequential(
-            nn.Linear(d_model, feedforward_dim),
-            ActivationBalancer(feedforward_dim,
-                               channel_dim=-1, max_abs=10.0),
-            DoubleSwish(),
-            nn.Dropout(dropout),
-            ScaledLinear(feedforward_dim, d_model,
-                         initial_scale=0.01),
-        )
+        self.feed_forward2 = FeedforwardModule(d_model,
+                                               feedforward_dim,
+                                               dropout)

-        self.conv_module = ConvolutionModule(d_model,
-                                             cnn_module_kernel)
+        self.feed_forward3 = FeedforwardModule(d_model,
+                                               feedforward_dim,
+                                               dropout)
+
+        self.conv_module1 = ConvolutionModule(d_model,
+                                              cnn_module_kernel)
+
+        self.conv_module2 = ConvolutionModule(d_model,
+                                              cnn_module_kernel)

         self.norm_final = BasicNorm(d_model)
@@ -330,10 +326,10 @@ class ConformerEncoderLayer(nn.Module):
         src_orig = src

         # macaron style feed forward module
-        src = src + self.feed_forward_macaron(src)
+        src = src + self.feed_forward1(src)

         # multi-headed self-attention module
-        src_att, _, attn_scores_out = self.self_attn(
+        src_att, attn_weights, attn_scores_out = self.self_attn(
             src,
             pos_emb=pos_emb,
             attn_scores_in=attn_scores_in,
@@ -343,11 +339,16 @@ class ConformerEncoderLayer(nn.Module):
         src = src + src_att

         # convolution module
-        src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)

         # feed forward module
-        src = src + self.feed_forward(src)
+        src = src + self.feed_forward2(src)
+
+        src = src + self.self_attn.forward2(src, attn_weights)
+
+        src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
+
+        src = src + self.feed_forward3(src)

         src = self.norm_final(self.balancer(src))
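Taken together, the refactored layer now applies seven residual branches before the final norm. A minimal sketch of the new data flow, with stand-in callables in place of the real modules (names follow the diff; tensors are assumed to be (seq_len, batch_size, d_model)):

from typing import Callable, Tuple
from torch import Tensor

# Sketch only: each argument stands in for the corresponding module above.
def encoder_layer_order(
    src: Tensor,
    feed_forward1: Callable[[Tensor], Tensor],
    self_attn: Callable[[Tensor], Tuple[Tensor, Tensor]],    # returns (out, attn_weights)
    conv_module1: Callable[[Tensor], Tensor],
    feed_forward2: Callable[[Tensor], Tensor],
    self_attn_forward2: Callable[[Tensor, Tensor], Tensor],  # re-uses attn_weights
    conv_module2: Callable[[Tensor], Tensor],
    feed_forward3: Callable[[Tensor], Tensor],
    norm_final: Callable[[Tensor], Tensor],
) -> Tensor:
    src = src + feed_forward1(src)                      # macaron-style feed forward
    src_att, attn_weights = self_attn(src)              # first attention pass
    src = src + src_att
    src = src + conv_module1(src)                       # first convolution module
    src = src + feed_forward2(src)                      # second feed forward
    src = src + self_attn_forward2(src, attn_weights)   # cheap second attention pass
    src = src + conv_module2(src)                       # second convolution module
    src = src + feed_forward3(src)                      # third feed forward
    return norm_final(src)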
@@ -846,6 +847,10 @@ class RelPositionMultiheadAttention(nn.Module):
             embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )

+        self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
+        self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
+                                      initial_scale=0.05)
+
         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
         self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
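Note the halved width: in_proj2 projects embed_dim down to embed_dim // 2 and out_proj2 maps it back, so the second attention pass (forward2 below) works with value vectors half the usual size. A quick consistency check with illustrative sizes:

# Illustrative sizes only; the real embed_dim / num_heads come from the config.
embed_dim, num_heads = 256, 8
half_dim = embed_dim // 2                  # out_features of in_proj2: 128
head_dim = embed_dim // (num_heads * 2)    # per-head value size in forward2: 16
assert half_dim == num_heads * head_dim    # the half width splits evenly across heads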
@@ -867,9 +872,8 @@ class RelPositionMultiheadAttention(nn.Module):
         pos_emb: Tensor,
         attn_scores_in: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    ) -> Tuple[Tensor, Tensor, Tensor]:
         r"""
         Args:
             x: input to be projected to query, key, value
@@ -879,7 +883,6 @@ class RelPositionMultiheadAttention(nn.Module):
             the corresponding value on the attention layer will be ignored. When given
             a byte mask and a value is non-zero, the corresponding value on the attention
             layer will be ignored
-            need_weights: output attn_output_weights.
             attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
             the batches while a 3D mask allows to specify a different mask for the entries of each batch.
@@ -902,11 +905,14 @@ class RelPositionMultiheadAttention(nn.Module):
             is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
             is provided, it will be added to the attention weight.

-        - Outputs:
-        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
-          E is the embedding dimension.
-        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
-          L is the target sequence length, S is the source sequence length.
+        - Returns: (attn_output, attn_weights, attn_scores)
+
+        - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size,
+          E is the embedding dimension.
+        - attn_weights: :math:`(N * H, S, S)` where N is the batch size, H is the num-heads
+          and S is the sequence length.
+        - attn_scores: :math:`(N, S, S, H)`, these are the attn weights
+          before softmax.
         """
         x, weights, scores = self.multi_head_attention_forward(
             self.in_balancer(self.in_proj(x)),
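The two returned weight tensors carry the same per-example information in different layouts: attn_weights is flattened over batch and heads, while attn_scores (the pre-softmax scores) keeps heads as a trailing axis so they can be mixed by attn_scores_proj_in/out. A shape-only illustration (it does not reproduce the head-mixing logic):

import torch

# Illustrative sizes: N = batch, H = num-heads, S = sequence length.
N, H, S = 4, 8, 10
attn_weights = torch.softmax(torch.randn(N * H, S, S), dim=-1)  # (N * H, S, S)
attn_scores = attn_weights.view(N, H, S, S).permute(0, 2, 3, 1)  # (N, S, S, H)
assert attn_scores.shape == (N, S, S, H)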
@@ -921,7 +927,6 @@ class RelPositionMultiheadAttention(nn.Module):
             self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
-            need_weights=need_weights,
             attn_mask=attn_mask,
         )
         if attn_scores_in is not None:
@@ -973,7 +978,6 @@ class RelPositionMultiheadAttention(nn.Module):
         out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
@@ -989,7 +993,6 @@ class RelPositionMultiheadAttention(nn.Module):
         key_padding_mask: if provided, specified padding elements in the key will
             be ignored by the attention. This is a binary mask. When the value is True,
             the corresponding value on the attention layer will be filled with -inf.
-        need_weights: output attn_output_weights.
         attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
             the batches while a 3D mask allows to specify a different mask for the entries of each batch.
@@ -1017,9 +1020,11 @@ class RelPositionMultiheadAttention(nn.Module):

         Outputs:
         - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
-          E is the embedding dimension.
-        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
-          L is the target sequence length, S is the source sequence length.
+          E is the embedding dimension.
+        - attn_weights: :math:`(N * H, S, S)` where N is the batch size,
+          H is the num-heads, S is the sequence length.
+        - attn_scores: :math:`(N, S, S, H)` where N is the batch size,
+          S is the sequence length and H is the num-heads.
         """

         tgt_len, bsz, _ = x.size()
@@ -1182,14 +1187,65 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output, out_proj_weight, out_proj_bias
         )

-        if need_weights:
-            # average attention weights over heads
-            attn_output_weights = attn_output_weights.view(
-                bsz, num_heads, tgt_len, src_len
-            )
-            return attn_output, attn_output_weights.sum(dim=1) / num_heads, attn_scores_out
-        else:
-            return attn_output, None, attn_scores_out
+        return attn_output, attn_output_weights, attn_scores_out
+
+    def forward2(
+        self,
+        x: Tensor,
+        attn_weights: Tensor,
+    ) -> Tensor:
+        """
+        Second forward function, where we re-use the attn_weights returned by
+        the first forward function, but with a different input.
+        Args:
+            x: input, of shape (seq_len, batch_size, embed_dim)
+            attn_weights: attention weights returned by forward(), of shape
+                (batch_size * num_heads, seq_len, seq_len)
+        Returns:
+            output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
+        """
+        num_heads = self.num_heads
+        (seq_len, bsz, embed_dim) = x.shape
+        head_dim = embed_dim // (num_heads * 2)
+        # v: (seq_len, bsz, embed_dim // 2)
+        v = self.in_proj2(x)
+        v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        # now v: (bsz * num_heads, seq_len, head_dim)
+        attn_output = torch.bmm(attn_weights, v)
+        # attn_output: (bsz * num_heads, seq_len, head_dim)
+        attn_output = (
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(seq_len, bsz, embed_dim // 2)
+        )
+        # returned value is of shape (seq_len, bsz, embed_dim), like x.
+        return self.out_proj2(attn_output)
+
+
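For intuition, here is a minimal, self-contained shape check of the forward2 data flow, with plain nn.Linear layers standing in for in_proj2 / out_proj2 and random softmaxed weights in place of a real attention pass (all sizes are hypothetical):

import torch

# Hypothetical sizes mirroring forward2's shape bookkeeping.
seq_len, bsz, embed_dim, num_heads = 10, 4, 256, 8
head_dim = embed_dim // (num_heads * 2)   # 16: value vectors are half-size

x = torch.randn(seq_len, bsz, embed_dim)
attn_weights = torch.softmax(
    torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1
)

# Stand-ins for self.in_proj2 / self.out_proj2 from the commit.
in_proj2 = torch.nn.Linear(embed_dim, embed_dim // 2, bias=False)
out_proj2 = torch.nn.Linear(embed_dim // 2, embed_dim, bias=True)

v = in_proj2(x)                                        # (seq_len, bsz, embed_dim // 2)
v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
out = torch.bmm(attn_weights, v)                       # (bsz * num_heads, seq_len, head_dim)
out = out.transpose(0, 1).contiguous().view(seq_len, bsz, embed_dim // 2)
out = out_proj2(out)
assert out.shape == x.shape                            # (seq_len, bsz, embed_dim)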
+class FeedforwardModule(nn.Module):
+    """Feedforward module in Conformer model."""
+    def __init__(self,
+                 d_model: int,
+                 feedforward_dim: int,
+                 dropout: float):
+        super(FeedforwardModule, self).__init__()
+        self.in_proj = nn.Linear(d_model, feedforward_dim)
+        self.balancer = ActivationBalancer(feedforward_dim,
+                                           channel_dim=-1, max_abs=10.0)
+        self.activation = DoubleSwish()
+        self.dropout = nn.Dropout(dropout)
+        self.out_proj = ScaledLinear(feedforward_dim, d_model,
+                                     initial_scale=0.01)
+
+    def forward(self,
+                x: Tensor):
+        x = self.in_proj(x)
+        x = self.balancer(x)
+        x = self.activation(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
 class ConvolutionModule(nn.Module):
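FeedforwardModule packages the former inline nn.Sequential feed-forward block behind a single reusable class. A rough drop-in sketch with plain PyTorch layers, where nn.SiLU stands in for DoubleSwish and ordinary nn.Linear for ScaledLinear / ActivationBalancer (icefall defines those in its scaling.py):

import torch
import torch.nn as nn

# Stand-in with the same interface as FeedforwardModule; not the icefall code.
class FeedforwardSketch(nn.Module):
    def __init__(self, d_model: int, feedforward_dim: int, dropout: float):
        super().__init__()
        self.in_proj = nn.Linear(d_model, feedforward_dim)
        self.activation = nn.SiLU()        # stand-in for DoubleSwish
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(feedforward_dim, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(self.dropout(self.activation(self.in_proj(x))))

ff = FeedforwardSketch(d_model=256, feedforward_dim=1024, dropout=0.1)
x = torch.randn(10, 4, 256)                # (seq_len, batch, d_model)
assert (x + ff(x)).shape == x.shape        # used as a residual branch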
@@ -92,7 +92,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="12,12",
+        default="7,7",
         help="Number of conformer encoder layers, comma separated.",
     )
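The default shrinks from "12,12" to "7,7", presumably to offset the extra modules each refactored layer now carries. The flag stays a comma-separated string, one count per encoder stack; a parse along these lines (hypothetical helper, not taken from train.py) yields the per-stack list:

# Hypothetical parsing of the comma-separated flag.
def parse_num_encoder_layers(spec: str) -> list:
    return [int(n) for n in spec.split(",")]

assert parse_num_encoder_layers("7,7") == [7, 7]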