Implement persistent attention scores

Daniel Povey 2022-09-22 18:47:16 +08:00
parent 03a77f8ae5
commit 0f85a3c2e5


@@ -211,6 +211,7 @@ class ConformerEncoderLayer(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
attn_scores_in: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
@@ -221,6 +222,8 @@ class ConformerEncoderLayer(nn.Module):
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
attn_scores_in: un-normalized attention scores of shape (batch_size, seq_len, seq_len, num_heads) that are
passed from layer to layer.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of layers; if < 1.0, we will
@@ -251,12 +254,13 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(self.feed_forward_macaron(src))
# multi-headed self-attention module
src_att = self.self_attn(
src_att, _, attn_scores_out = self.self_attn(
src,
pos_emb=pos_emb,
attn_scores_in=attn_scores_in,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
)
src = src + self.dropout(src_att)
# convolution module
@@ -270,7 +274,7 @@ class ConformerEncoderLayer(nn.Module):
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
return src
return src, attn_scores_out
class ConformerEncoder(nn.Module):
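For reference, a minimal, self-contained sketch of the layer-level contract introduced above: the self-attention block now returns a third value instead of the caller taking element `[0]`, and the layer hands that value back alongside its output. `ToyAttention` and `ToyEncoderLayer` are hypothetical stand-ins invented for this illustration, not code from the repository.

# Illustrative sketch only: a stripped-down layer that threads a running
# attention-score tensor through its self-attention block.
from typing import Optional, Tuple
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    """Stand-in attention block returning (output, weights, scores_out)."""

    def forward(
        self, x: torch.Tensor, scores_in: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        batch, seq, _ = x.shape
        num_heads = 4  # assumed head count for the toy example
        # A real module would compute per-head attention logits here.
        scores_out = torch.zeros(batch, seq, seq, num_heads)
        if scores_in is not None:
            scores_out = scores_out + scores_in  # accumulate the running state
        return x, None, scores_out

class ToyEncoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.self_attn = ToyAttention()

    def forward(
        self, src: torch.Tensor, scores_in: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Previously: src_att = self.self_attn(src)[0]; now the third return
        # value is kept and passed back to the caller with the output.
        src_att, _, scores_out = self.self_attn(src, scores_in)
        return src + src_att, scores_out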
@@ -338,11 +342,13 @@ class ConformerEncoder(nn.Module):
output = src
outputs = []
attn_scores = None
for i, mod in enumerate(self.layers):
output = mod(
output, attn_scores = mod(
output,
pos_emb,
attn_scores,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
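Continuing the same illustration, the encoder-level threading in this hunk amounts to initializing the running scores to None and letting each layer's second return value feed the next layer. The loop below reuses the hypothetical `ToyEncoderLayer` sketched above.

# Illustrative sketch of the layer-to-layer threading in ConformerEncoder.
import torch
import torch.nn as nn

layers = nn.ModuleList([ToyEncoderLayer() for _ in range(6)])

def run_stack(src: torch.Tensor) -> torch.Tensor:
    output = src
    attn_scores = None  # no persistent scores before the first layer
    for mod in layers:
        # Each layer both consumes and refreshes the persistent score state.
        output, attn_scores = mod(output, attn_scores)
    return output

out = run_stack(torch.randn(2, 10, 32))  # (batch, seq, feature)
print(out.shape)  # torch.Size([2, 10, 32])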
@@ -477,6 +483,9 @@ class RelPositionMultiheadAttention(nn.Module):
embed_dim, embed_dim, bias=True, initial_scale=0.5
)
self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable biases are used in matrix c and matrix d
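A note on the two new mixing matrices added in this hunk, and the (assumed) intent of their initializations: `attn_scores_proj_in` starts as the identity, so any incoming score state reaches this layer's attention logits unchanged per head, while `attn_scores_proj_out` starts as zeros, so at initialization no layer adds anything new to the running state and the stack behaves exactly like the baseline without persistence. The snippet below is a small sketch of that property, not a test from the repository.

# Sketch: with an identity in-projection and a zero out-projection, the
# persistent score state is passed along untouched at initialization.
import torch

num_heads = 8
batch, seq = 2, 5

proj_in = torch.eye(num_heads)                   # like attn_scores_proj_in
proj_out = torch.zeros(num_heads, num_heads)     # like attn_scores_proj_out

scores_in = torch.randn(batch, seq, seq, num_heads)   # running state
raw_scores = torch.randn(batch, seq, seq, num_heads)  # this layer's logits

mixed_in = torch.matmul(scores_in, proj_in)           # added to the logits
scores_out = torch.matmul(raw_scores, proj_out) + scores_in

assert torch.allclose(mixed_in, scores_in)    # identity: pure pass-through
assert torch.allclose(scores_out, scores_in)  # zeros: no new contribution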
@@ -493,10 +502,11 @@ class RelPositionMultiheadAttention(nn.Module):
self,
x: Tensor,
pos_emb: Tensor,
attn_scores_in: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
r"""
Args:
x: input to be projected to query, key, value
@@ -516,6 +526,7 @@ class RelPositionMultiheadAttention(nn.Module):
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- attn_scores_in: :math:`(N, L, L, num_heads)`
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
@@ -534,9 +545,10 @@ class RelPositionMultiheadAttention(nn.Module):
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
return self.multi_head_attention_forward(
x, weights, scores = self.multi_head_attention_forward(
self.in_balancer(self.in_proj(x)),
pos_emb,
None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
self.embed_dim,
self.num_heads,
self.in_proj.weight,
@@ -549,6 +561,10 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights=need_weights,
attn_mask=attn_mask,
)
attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
if attn_scores_in is not None:
attn_scores_out = attn_scores_out + attn_scores_in
return x, weights, attn_scores_out
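At the wrapper level, the score flow is: the incoming state is mixed across heads by `attn_scores_proj_in` before the core forward consumes it, and the fresh raw scores coming back from the core are mixed by `attn_scores_proj_out` and added residually onto the incoming state. The sketch below only illustrates that flow and the matmul broadcasting over the trailing head axis; `mix_scores` is an invented helper, and the shapes follow the docstring above.

# Sketch of the wrapper-level mixing: torch.matmul acts on the last (head)
# axis, so an (N, L, L, H) score tensor times an (H, H) matrix stays
# (N, L, L, H).
from typing import Optional, Tuple
import torch

def mix_scores(
    raw_scores: torch.Tensor,           # (N, L, L, H), this layer's logits
    scores_in: Optional[torch.Tensor],  # (N, L, L, H) running state, or None
    proj_in: torch.Tensor,              # (H, H)
    proj_out: torch.Tensor,             # (H, H)
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Return (scores fed to the attention logits, updated running state)."""
    mixed_in = None if scores_in is None else torch.matmul(scores_in, proj_in)
    scores_out = torch.matmul(raw_scores, proj_out)
    if scores_in is not None:
        scores_out = scores_out + scores_in   # residual accumulation
    return mixed_in, scores_out

N, L, H = 2, 7, 8
_, state = mix_scores(
    torch.randn(N, L, L, H), torch.randn(N, L, L, H),
    torch.eye(H), torch.zeros(H, H),
)
print(state.shape)  # torch.Size([2, 7, 7, 8])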
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
@@ -579,6 +595,7 @@ class RelPositionMultiheadAttention(nn.Module):
self,
x: Tensor,
pos_emb: Tensor,
attn_scores_in: Optional[Tensor],
embed_dim: int,
num_heads: int,
in_proj_weight: Tensor,
@@ -747,6 +764,12 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) # (batch, head, time1, time2)
attn_scores_out = attn_output_weights.permute(0, 2, 3, 1) # (batch, time1, time2, head)
if attn_scores_in is not None:
attn_output_weights = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
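Inside the core forward, the persistent scores are kept in (batch, time1, time2, head) layout, with the head axis last so the (num_heads, num_heads) mixing matrices can act on it, while the logits are (batch, head, time1, time2). The two permutes only move the head axis, and the incoming state is added to the un-normalized scores (matrix_ac + matrix_bd), so it acts as a per-head additive bias on this layer's attention. A small shape sketch, illustrative only:

# Sketch of the layout conversion around the un-normalized attention scores.
import torch

bsz, num_heads, t1, t2 = 2, 8, 5, 5
attn_weights = torch.randn(bsz, num_heads, t1, t2)  # matrix_ac + matrix_bd

# Persisted layout: head axis last, matching the (H, H) mixing matrices.
scores_out = attn_weights.permute(0, 2, 3, 1)       # (bsz, t1, t2, head)

# Incoming state returns to the logits' layout and biases them additively.
scores_in = torch.randn(bsz, t1, t2, num_heads)
attn_weights = attn_weights + scores_in.permute(0, 3, 1, 2)

assert scores_out.shape == (bsz, t1, t2, num_heads)
assert attn_weights.shape == (bsz, num_heads, t1, t2)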
@@ -796,9 +819,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
return attn_output, attn_output_weights.sum(dim=1) / num_heads, attn_scores_out
else:
return attn_output, None
return attn_output, None, attn_scores_out
class ConvolutionModule(nn.Module):