Implement persistent attention scores

Daniel Povey 2022-09-22 18:47:16 +08:00
parent 03a77f8ae5
commit 0f85a3c2e5


@@ -211,6 +211,7 @@ class ConformerEncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
+        attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
@@ -221,6 +222,8 @@ class ConformerEncoderLayer(nn.Module):
         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
+            attn_scores_in: a tensor with the shape of the attention weights,
+                (batch, time1, time2, num_heads), that is passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
             warmup: controls selective bypass of layers; if < 1.0, we will
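The docstring above is the heart of the change: each encoder layer both consumes and produces a tensor of per-head attention scores, so the scores persist across the layer stack (the `ConformerEncoder.forward` hunk further below threads them in exactly this way). A minimal, self-contained sketch of the threading pattern; `ToyLayer` and its shapes are illustrative stand-ins, not the actual `ConformerEncoderLayer`:

```python
import torch
import torch.nn as nn
from typing import Optional, Tuple


class ToyLayer(nn.Module):
    """Illustrative stand-in: consumes and re-emits per-head attention scores."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)
        self.num_heads = num_heads

    def forward(
        self, x: torch.Tensor, scores_in: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        N, T, _ = x.shape
        # Scores this layer would compute, shape (batch, time1, time2, num_heads).
        scores_out = torch.zeros(N, T, T, self.num_heads)
        if scores_in is not None:
            scores_out = scores_out + scores_in  # accumulate across layers
        return self.proj(x), scores_out


layers = nn.ModuleList(ToyLayer(16, 4) for _ in range(3))
x = torch.randn(2, 5, 16)
scores = None  # the first layer receives no scores, as in ConformerEncoder.forward
for layer in layers:
    x, scores = layer(x, scores)
print(scores.shape)  # torch.Size([2, 5, 5, 4])
```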
@@ -251,12 +254,13 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(self.feed_forward_macaron(src))

         # multi-headed self-attention module
-        src_att = self.self_attn(
+        src_att, _, attn_scores_out = self.self_attn(
             src,
             pos_emb=pos_emb,
+            attn_scores_in=attn_scores_in,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
-        )[0]
+        )
         src = src + self.dropout(src_att)

         # convolution module
@@ -270,7 +274,7 @@ class ConformerEncoderLayer(nn.Module):
         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig

-        return src
+        return src, attn_scores_out


 class ConformerEncoder(nn.Module):
@@ -338,11 +342,13 @@ class ConformerEncoder(nn.Module):
         output = src

         outputs = []
+        attn_scores = None

         for i, mod in enumerate(self.layers):
-            output = mod(
+            output, attn_scores = mod(
                 output,
                 pos_emb,
+                attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
@@ -477,6 +483,9 @@ class RelPositionMultiheadAttention(nn.Module):
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )

+        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
+        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
+
         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
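The two new parameters are head-mixing matrices of shape (num_heads, num_heads): `attn_scores_proj_in` starts as the identity, so incoming scores initially pass through unchanged, while `attn_scores_proj_out` starts at zero, so a freshly initialized layer contributes nothing to the persistent scores and the module begins by behaving like the baseline. A small sketch of how such a matrix acts on a score tensor (toy shapes, not the real module):

```python
import torch

num_heads, N, L = 4, 2, 6
scores = torch.randn(N, L, L, num_heads)      # (batch, time1, time2, head)
proj_in = torch.eye(num_heads)                # identity init: pass-through
proj_out = torch.zeros(num_heads, num_heads)  # zero init: no contribution yet

# matmul mixes only the trailing head dimension, broadcasting over (N, L, L).
mixed_in = torch.matmul(scores, proj_in)
assert torch.allclose(mixed_in, scores)       # identity leaves the scores unchanged

contribution = torch.matmul(scores, proj_out)
assert contribution.abs().sum() == 0          # zero matrix produces no new scores
```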
@@ -493,10 +502,11 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
+        attn_scores_in: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
         r"""
         Args:
             x: input to be projected to query, key, value
@@ -516,6 +526,7 @@ class RelPositionMultiheadAttention(nn.Module):
               the embedding dimension.
             - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
               the embedding dimension.
+            - attn_scores_in: :math:`(N, L, L, num_heads)`
             - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
               If a ByteTensor is provided, the non-zero positions will be ignored while the position
               with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
@@ -534,9 +545,10 @@ class RelPositionMultiheadAttention(nn.Module):
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
               L is the target sequence length, S is the source sequence length.
         """
-        return self.multi_head_attention_forward(
+        x, weights, scores = self.multi_head_attention_forward(
             self.in_balancer(self.in_proj(x)),
             pos_emb,
+            None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
             self.embed_dim,
             self.num_heads,
             self.in_proj.weight,
@@ -549,6 +561,10 @@ class RelPositionMultiheadAttention(nn.Module):
             need_weights=need_weights,
             attn_mask=attn_mask,
         )
+        attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
+        if attn_scores_in is not None:
+            attn_scores_out = attn_scores_out + attn_scores_in
+        return x, weights, attn_scores_out

     def rel_shift(self, x: Tensor) -> Tensor:
         """Compute relative positional encoding.
@@ -579,6 +595,7 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
+        attn_scores_in: Optional[Tensor],
         embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
@@ -747,6 +764,12 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         )  # (batch, head, time1, time2)

+        attn_scores_out = attn_output_weights.permute(0, 2, 3, 1)  # (batch, time1, time2, head)
+        if attn_scores_in is not None:
+            attn_output_weights = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)
+
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
         )
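Inside `multi_head_attention_forward` the logits live in (batch, head, time1, time2) layout, while the persistent scores travel between layers as (batch, time1, time2, head), so a permute is needed in each direction. A small shape check on toy tensors (not the actual logits):

```python
import torch

bsz, num_heads, time1, time2 = 2, 4, 5, 5
attn_logits = torch.randn(bsz, num_heads, time1, time2)  # internal attention layout
scores_in = torch.randn(bsz, time1, time2, num_heads)    # layer-to-layer layout

scores_out = attn_logits.permute(0, 2, 3, 1)               # -> (batch, time1, time2, head)
attn_logits = attn_logits + scores_in.permute(0, 3, 1, 2)  # -> (batch, head, time1, time2)

assert scores_out.shape == (bsz, time1, time2, num_heads)
assert attn_logits.shape == (bsz, num_heads, time1, time2)
```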
@@ -796,9 +819,9 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            return attn_output, attn_output_weights.sum(dim=1) / num_heads
+            return attn_output, attn_output_weights.sum(dim=1) / num_heads, attn_scores_out
         else:
-            return attn_output, None
+            return attn_output, None, attn_scores_out


 class ConvolutionModule(nn.Module):