From 0f85a3c2e5b71f2600a171ebbe064bcee6ce8cea Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 22 Sep 2022 18:47:16 +0800
Subject: [PATCH] Implement persistent attention scores

---
 .../pruned_transducer_stateless7/conformer.py | 39 +++++++++++++++----
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 182b78eee..92f3f2dc7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -211,6 +211,7 @@ class ConformerEncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
+        attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
@@ -221,6 +222,8 @@ class ConformerEncoderLayer(nn.Module):
         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
+            attn_scores_in: attention scores of shape (batch_size, seq_len, seq_len, num_heads)
+              that are passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
             warmup: controls selective bypass of of layers; if < 1.0, we will
@@ -251,12 +254,13 @@ class ConformerEncoderLayer(nn.Module):
             src = src + self.dropout(self.feed_forward_macaron(src))
 
         # multi-headed self-attention module
-        src_att = self.self_attn(
+        src_att, _, attn_scores_out = self.self_attn(
             src,
             pos_emb=pos_emb,
+            attn_scores_in=attn_scores_in,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
-        )[0]
+        )
         src = src + self.dropout(src_att)
 
         # convolution module
@@ -270,7 +274,7 @@ class ConformerEncoderLayer(nn.Module):
         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig
 
-        return src
+        return src, attn_scores_out
 
 
 class ConformerEncoder(nn.Module):
@@ -338,11 +342,13 @@ class ConformerEncoder(nn.Module):
         output = src
 
         outputs = []
+        attn_scores = None
 
         for i, mod in enumerate(self.layers):
-            output = mod(
+            output, attn_scores = mod(
                 output,
                 pos_emb,
+                attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
@@ -477,6 +483,9 @@ class RelPositionMultiheadAttention(nn.Module):
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )
 
+        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
+        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
+
         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
@@ -493,10 +502,11 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
+        attn_scores_in: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
         r"""
         Args:
             x: input to be projected to query, key, value
@@ -516,6 +526,7 @@ class RelPositionMultiheadAttention(nn.Module):
               the embedding dimension.
             - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence
               length, N is the batch size, E is the embedding dimension.
+            - attn_scores_in: :math:`(N, L, L, num_heads)`
             - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
@@ -534,9 +545,10 @@ class RelPositionMultiheadAttention(nn.Module):
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
         """
-        return self.multi_head_attention_forward(
+        x, weights, scores = self.multi_head_attention_forward(
             self.in_balancer(self.in_proj(x)),
             pos_emb,
+            None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
             self.embed_dim,
             self.num_heads,
             self.in_proj.weight,
@@ -549,6 +561,10 @@ class RelPositionMultiheadAttention(nn.Module):
             need_weights=need_weights,
             attn_mask=attn_mask,
         )
+        attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
+        if attn_scores_in is not None:
+            attn_scores_out = attn_scores_out + attn_scores_in
+        return x, weights, attn_scores_out
 
     def rel_shift(self, x: Tensor) -> Tensor:
         """Compute relative positional encoding.
@@ -579,6 +595,7 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
+        attn_scores_in: Optional[Tensor],
         embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
@@ -747,6 +764,12 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         )  # (batch, head, time1, time2)
 
+        attn_scores_out = attn_output_weights.permute(0, 2, 3, 1)  # (batch, time1, time2, head)
+
+        if attn_scores_in is not None:
+            attn_output_weights = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)
+
+
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
         )
@@ -796,9 +819,9 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            return attn_output, attn_output_weights.sum(dim=1) / num_heads
+            return attn_output, attn_output_weights.sum(dim=1) / num_heads, attn_scores_out
         else:
-            return attn_output, None
+            return attn_output, None, attn_scores_out
 
 
 class ConvolutionModule(nn.Module):
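
Note (not part of the patch): below is a minimal, self-contained sketch of the mechanism this diff adds. Each layer's fresh pre-softmax attention scores are mixed across heads by a zero-initialized (num_heads, num_heads) projection and added residually to a score stream that the encoder threads from layer to layer; each layer reads the incoming stream through an identity-initialized projection and adds it to its own attention logits before the softmax. All names here (PersistentScoreAttention, etc.) are invented for illustration, and the sketch omits the relative positional encoding, masking, dropout, and balancers of the real RelPositionMultiheadAttention.

import torch
import torch.nn as nn


class PersistentScoreAttention(nn.Module):
    """Toy self-attention layer that reads and writes a persistent score stream."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        # Identity init: at the start each head reads back exactly its own
        # accumulated scores from the stream.
        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
        # Zero init: at the start a layer contributes nothing new to the stream.
        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))

    def forward(self, x, attn_scores_in=None):
        # x: (batch, time, embed_dim); stream: (batch, time1, time2, num_heads)
        b, t, _ = x.shape
        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        q = q.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

        # Fresh pre-softmax scores for this layer: (batch, head, time1, time2).
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_dim**0.5

        # This layer's contribution to the stream is its fresh scores, mixed
        # across heads by the zero-initialized output projection.
        attn_scores_out = torch.matmul(
            scores.permute(0, 2, 3, 1), self.attn_scores_proj_out
        )  # (batch, time1, time2, head)

        if attn_scores_in is not None:
            # Read the incoming stream through the identity-initialized input
            # projection and add it to this layer's logits before the softmax ...
            mixed_in = torch.matmul(attn_scores_in, self.attn_scores_proj_in)
            scores = scores + mixed_in.permute(0, 3, 1, 2)
            # ... and carry the stream forward with a residual connection.
            attn_scores_out = attn_scores_out + attn_scores_in

        attn = scores.softmax(dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(b, t, -1)
        return self.out_proj(out), attn_scores_out


# The encoder threads the stream through its layers, analogous to
# ConformerEncoder.forward in the patch:
layers = nn.ModuleList(PersistentScoreAttention(256, 4) for _ in range(3))
x, attn_scores = torch.randn(2, 10, 256), None
for layer in layers:
    x, attn_scores = layer(x, attn_scores)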