Mirror of https://github.com/k2-fsa/icefall.git

Commit 0f85a3c2e5: Implement persistent attention scores
Parent: 03a77f8ae5
@@ -211,6 +211,7 @@ class ConformerEncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
+        attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
@@ -221,6 +222,8 @@ class ConformerEncoderLayer(nn.Module):
         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
+            attn_scores_in: accumulated attention scores of shape (batch, len, len, num_heads)
+                that are passed from layer to layer (optional).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
             warmup: controls selective bypass of of layers; if < 1.0, we will
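As a shape check only (a standalone snippet, not part of the commit), an attn_scores_in tensor matching the documented layout for a batch of 2, sequence length 6, and 8 heads looks like this:

import torch

batch, seq_len, num_heads = 2, 6, 8

# Accumulated attention scores handed from one ConformerEncoderLayer to the next;
# the first layer in the stack receives None instead.
attn_scores_in = torch.zeros(batch, seq_len, seq_len, num_heads)
print(attn_scores_in.shape)  # torch.Size([2, 6, 6, 8])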
@@ -251,12 +254,13 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(self.feed_forward_macaron(src))

         # multi-headed self-attention module
-        src_att = self.self_attn(
+        src_att, _, attn_scores_out = self.self_attn(
             src,
             pos_emb=pos_emb,
+            attn_scores_in=attn_scores_in,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
-        )[0]
+        )
         src = src + self.dropout(src_att)

         # convolution module
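For readers unfamiliar with the call-site change, the following self-contained toy (a hypothetical ToyAttn, not the real RelPositionMultiheadAttention) mimics the new interface: the attention module now returns a triple instead of a pair, so the enclosing layer drops the old `[0]` indexing and keeps the scores to hand to the next layer.

from typing import Optional, Tuple

import torch
import torch.nn as nn


class ToyAttn(nn.Module):
    """Minimal stand-in that returns the same triple as the patched self_attn."""

    def forward(
        self, x: torch.Tensor, attn_scores_in: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Plain dot-product scores, shape (batch, time, time) in this toy.
        scores = torch.matmul(x, x.transpose(1, 2)) / x.shape[-1] ** 0.5
        if attn_scores_in is not None:
            scores = scores + attn_scores_in
        weights = scores.softmax(dim=-1)
        return torch.matmul(weights, x), weights, scores


attn = ToyAttn()
src = torch.randn(2, 6, 16)

# The call site unpacks all three outputs and keeps the scores for the next layer.
src_att, _, attn_scores_out = attn(src, attn_scores_in=None)
src = src + src_att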
@@ -270,7 +274,7 @@ class ConformerEncoderLayer(nn.Module):
         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig

-        return src
+        return src, attn_scores_out


 class ConformerEncoder(nn.Module):
@@ -338,11 +342,13 @@ class ConformerEncoder(nn.Module):
         output = src

         outputs = []
+        attn_scores = None

         for i, mod in enumerate(self.layers):
-            output = mod(
+            output, attn_scores = mod(
                 output,
                 pos_emb,
+                attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
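At the encoder level the change is pure threading: the scores start as None and whatever layer i returns becomes the input of layer i+1. A minimal standalone sketch of that loop, with a placeholder toy_layer instead of the real module:

import torch

num_heads, batch, t = 4, 2, 6


def toy_layer(x, attn_scores_in):
    """Placeholder for ConformerEncoderLayer: returns (output, scores for the next layer)."""
    scores = torch.randn(batch, t, t, num_heads) * 0.01
    if attn_scores_in is not None:
        scores = scores + attn_scores_in  # residual accumulation across layers
    return x, scores


output = torch.randn(batch, t, 16)
attn_scores = None                    # nothing to pass into the first layer
for _ in range(3):                    # stands in for `for i, mod in enumerate(self.layers)`
    output, attn_scores = toy_layer(output, attn_scores)
print(attn_scores.shape)              # torch.Size([2, 6, 6, 4])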
@@ -477,6 +483,9 @@ class RelPositionMultiheadAttention(nn.Module):
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )

+        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
+        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
+
         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
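The two new parameters are per-head mixing matrices applied to score tensors of shape (batch, time1, time2, num_heads); torch.matmul with a (num_heads, num_heads) matrix acts on the trailing head dimension. The identity initialization of attn_scores_proj_in means incoming scores are initially passed through unmixed, as the standalone check below illustrates (toy shapes, not from the commit):

import torch

num_heads, batch, t = 4, 2, 5
attn_scores_proj_in = torch.eye(num_heads)        # same init as in the commit

scores_in = torch.randn(batch, t, t, num_heads)   # (batch, time1, time2, head)

# matmul with a (head, head) matrix mixes only the trailing head dimension:
# mixed[..., h] = sum_k scores_in[..., k] * proj[k, h]
mixed = torch.matmul(scores_in, attn_scores_proj_in)

# With the identity initialization the incoming scores are unchanged.
assert torch.allclose(mixed, scores_in)
print(mixed.shape)  # torch.Size([2, 5, 5, 4])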
@@ -493,10 +502,11 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
+        attn_scores_in: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
         r"""
         Args:
             x: input to be projected to query, key, value
@@ -516,6 +526,7 @@ class RelPositionMultiheadAttention(nn.Module):
               the embedding dimension.
             - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
               the embedding dimension.
+            - attn_scores_in: :math:`(N, L, L, num_heads)`
             - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
               If a ByteTensor is provided, the non-zero positions will be ignored while the position
               with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
@@ -534,9 +545,10 @@ class RelPositionMultiheadAttention(nn.Module):
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
               L is the target sequence length, S is the source sequence length.
         """
-        return self.multi_head_attention_forward(
+        x, weights, scores = self.multi_head_attention_forward(
             self.in_balancer(self.in_proj(x)),
             pos_emb,
+            None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
             self.embed_dim,
             self.num_heads,
             self.in_proj.weight,
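The input projection is guarded so the first layer, which receives no scores, still works. A tiny standalone equivalent of that expression, using a hypothetical helper name project_incoming:

import torch

num_heads = 4
attn_scores_proj_in = torch.eye(num_heads)


def project_incoming(attn_scores_in):
    # Hypothetical helper mirroring the guarded expression in the hunk above.
    if attn_scores_in is None:
        return None
    return torch.matmul(attn_scores_in, attn_scores_proj_in)


print(project_incoming(None))                                    # None (first layer)
print(project_incoming(torch.randn(2, 5, 5, num_heads)).shape)   # torch.Size([2, 5, 5, 4])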
@@ -549,6 +561,10 @@ class RelPositionMultiheadAttention(nn.Module):
             need_weights=need_weights,
             attn_mask=attn_mask,
         )
+        attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
+        if attn_scores_in is not None:
+            attn_scores_out = attn_scores_out + attn_scores_in
+        return x, weights, attn_scores_out

     def rel_shift(self, x: Tensor) -> Tensor:
         """Compute relative positional encoding.
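The outgoing stream is a residual update: this layer's raw scores are mixed by attn_scores_proj_out and added to whatever arrived. Because that matrix starts at zero, a freshly initialized layer is a pure pass-through of the stream, which the standalone check below demonstrates (toy tensors, not from the commit):

import torch

num_heads = 4
attn_scores_proj_out = torch.zeros(num_heads, num_heads)   # same init as in the commit

scores = torch.randn(2, 5, 5, num_heads)           # this layer's own raw attention logits
attn_scores_in = torch.randn(2, 5, 5, num_heads)   # accumulated logits from earlier layers

attn_scores_out = torch.matmul(scores, attn_scores_proj_out)
attn_scores_out = attn_scores_out + attn_scores_in  # residual: earlier layers always flow through

# Zero initialization: the layer starts as a pure pass-through of the stream.
assert torch.equal(attn_scores_out, attn_scores_in)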
@@ -579,6 +595,7 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
+        attn_scores_in: Optional[Tensor],
         embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
@@ -747,6 +764,12 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         )  # (batch, head, time1, time2)

+        attn_scores_out = attn_output_weights.permute(0, 2, 3, 1)  # (batch, time1, time2, head)
+
+        if attn_scores_in is not None:
+            attn_output_weights = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)
+
+
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
         )
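Inside multi_head_attention_forward the logits live in (batch, head, time1, time2) layout while the persistent stream uses (batch, time1, time2, head), hence the two permutes. Note that attn_scores_out is taken before the incoming scores are added, so each layer exports only its own logits; accumulation happens through the residual shown two hunks above. A standalone shape walk-through with toy tensors:

import torch

batch, num_heads, t = 2, 4, 5

# This layer's pre-softmax logits, in the layout used inside the attention code.
attn_output_weights = torch.randn(batch, num_heads, t, t)     # (batch, head, time1, time2)

# Exported with the head dimension moved last, before the incoming scores are added.
attn_scores_out = attn_output_weights.permute(0, 2, 3, 1)     # (batch, time1, time2, head)

# Incoming scores arrive in (batch, time1, time2, head) layout and are permuted
# back before being added to the logits.
attn_scores_in = torch.randn(batch, t, t, num_heads)
attn_output_weights = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)

print(attn_scores_out.shape)       # torch.Size([2, 5, 5, 4])
print(attn_output_weights.shape)   # torch.Size([2, 4, 5, 5])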
@@ -796,9 +819,9 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            return attn_output, attn_output_weights.sum(dim=1) / num_heads
+            return attn_output, attn_output_weights.sum(dim=1) / num_heads, attn_scores_out
         else:
-            return attn_output, None
+            return attn_output, None, attn_scores_out


 class ConvolutionModule(nn.Module):
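Putting the pieces together, here is a self-contained toy summary of the mechanism this commit adds (a hypothetical PersistentScoreAttention with projection-free attention, not the real Conformer code): each layer adds the head-mixed incoming stream to its own pre-softmax logits, and updates the stream residually with a zero-initialized mix of its own logits.

import torch
import torch.nn as nn


class PersistentScoreAttention(nn.Module):
    def __init__(self, num_heads: int) -> None:
        super().__init__()
        self.proj_in = nn.Parameter(torch.eye(num_heads))                # cf. attn_scores_proj_in
        self.proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))  # cf. attn_scores_proj_out
        self.num_heads = num_heads

    def forward(self, x, stream=None):
        b, t, d = x.shape
        h = self.num_heads
        q = x.view(b, t, h, d // h).transpose(1, 2)                       # (b, h, t, d/h)
        logits = torch.matmul(q, q.transpose(-2, -1)) / (d // h) ** 0.5   # (b, h, t, t)
        own_scores = logits.permute(0, 2, 3, 1)                           # (b, t, t, h)
        if stream is not None:
            # Add the head-mixed incoming stream to this layer's logits.
            logits = logits + torch.matmul(stream, self.proj_in).permute(0, 3, 1, 2)
        weights = logits.softmax(dim=-1)
        out = torch.matmul(weights, q).transpose(1, 2).reshape(b, t, d)
        # Residual update of the persistent stream.
        new_stream = torch.matmul(own_scores, self.proj_out)
        if stream is not None:
            new_stream = new_stream + stream
        return out, new_stream


layers = nn.ModuleList([PersistentScoreAttention(num_heads=4) for _ in range(3)])
x, stream = torch.randn(2, 6, 16), None
for layer in layers:
    x, stream = layer(x, stream)
print(stream.shape)   # torch.Size([2, 6, 6, 4])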