From eb58e6d74b70e80ea2bdc3d1889ad8d0ea3c1b4e Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 12 Oct 2022 12:50:00 +0800
Subject: [PATCH] Remove persistent attention scores.

---
 .../pruned_transducer_stateless7/conformer.py | 44 ++++---------
 1 file changed, 8 insertions(+), 36 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index cef8d1b18..d15da06c6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -300,18 +300,15 @@ class ConformerEncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tensor:
         """
         Pass the input through the encoder layer.
 
         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
-            attn_scores_in: something with the dimension fo attention weights (bsz, len, len, num_heads) that is
-               passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
             batch_split: if not None, this layer will only be applied to
@@ -329,10 +326,9 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.feed_forward1(src)
 
         # multi-headed self-attention module
-        src_att, attn_weights, attn_scores_out = self.self_attn(
+        src_att, attn_weights = self.self_attn(
             src,
             pos_emb=pos_emb,
-            attn_scores_in=attn_scores_in,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
@@ -358,7 +354,7 @@ class ConformerEncoderLayer(nn.Module):
             bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale
 
-        return src, attn_scores_out
+        return src
 
 
 class ConformerEncoder(nn.Module):
@@ -516,7 +512,6 @@ class ConformerEncoder(nn.Module):
         output = src
 
         outputs = []
-        attn_scores = None
 
         rnd_seed = src.numel() + random.randint(0, 1000)
 
@@ -527,10 +522,9 @@ class ConformerEncoder(nn.Module):
         for i, mod in enumerate(self.layers):
             if i in layers_to_drop:
                 continue
-            output, attn_scores = mod(
+            output = mod(
                 output,
                 pos_emb,
-                attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
@@ -851,8 +845,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
 
-        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
-        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
 
         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
@@ -870,7 +862,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
-        attn_scores_in: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor, Tensor]:
@@ -892,7 +883,6 @@ class RelPositionMultiheadAttention(nn.Module):
               the embedding dimension.
             - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
               the embedding dimension.
-            - attn_scores_in: :math:`(N, L, L, num_heads)`
             - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
               If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
@@ -905,19 +895,16 @@ class RelPositionMultiheadAttention(nn.Module):
              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.
 
-        - Returns: (attn_output, attn_weights, attn_scores)
+        - Returns: (attn_output, attn_weights)
 
             - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_weights: :math:`(N * N, S, S)` where N is the batch size,
              H is the num-heads and S is the sequence length.
-            - attn_scores: :math:`(N, S, S, H)`, these are the attn weights
-              before softmax.
        """
-        x, weights, scores = self.multi_head_attention_forward(
+        x, weights = self.multi_head_attention_forward(
            self.in_balancer(self.in_proj(x)),
            pos_emb,
-            None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
            self.embed_dim,
            self.num_heads,
            self.in_proj.weight,
@@ -929,15 +916,7 @@ class RelPositionMultiheadAttention(nn.Module):
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
-        if attn_scores_in is not None:
-            attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
-            attn_scores_out = attn_scores_out + attn_scores_in
-        else:
-            # Here, add self.attn_scores_proj_in in order to make sure it has
-            # a grad.
-            attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out +
-                                           self.attn_scores_proj_in)
-        return x, weights, attn_scores_out
+        return x, weights
 
     def rel_shift(self, x: Tensor) -> Tensor:
         """Compute relative positional encoding.
@@ -968,7 +947,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
-        attn_scores_in: Optional[Tensor],
         embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
@@ -1023,8 +1001,6 @@ class RelPositionMultiheadAttention(nn.Module):
               E is the embedding dimension.
             - attn_weights: :math:`(N * H, S, S)` where N is the batch size,
               H is the num-heads, S is the sequence length.
-            - attn_scores: :math:`(N, S, S, H)` where N is the batch size,
-              S is the sequence length and H is the num-heads.
         """
 
         tgt_len, bsz, _ = x.size()
@@ -1137,10 +1113,6 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         )  # (batch, head, time1, time2)
 
-        attn_scores_out = attn_output_weights.permute(0, 2, 3, 1)  # (batch, time1, time2, head)
-
-        if attn_scores_in is not None:
-            attn_output_weights = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)
 
 
         attn_output_weights = attn_output_weights.view(
@@ -1187,7 +1159,7 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output, out_proj_weight, out_proj_bias
         )
 
-        return attn_output, attn_output_weights, attn_scores_out
+        return attn_output, attn_output_weights
 
     def forward2(