Remove persistent attention scores.

This commit is contained in:
Daniel Povey 2022-10-12 12:50:00 +08:00
parent 1825336841
commit eb58e6d74b

View File

@ -300,18 +300,15 @@ class ConformerEncoderLayer(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
attn_scores_in: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
attn_scores_in: something with the dimension fo attention weights (bsz, len, len, num_heads) that is
passed from layer to layer.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
batch_split: if not None, this layer will only be applied to
@ -329,10 +326,9 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.feed_forward1(src)
# multi-headed self-attention module
src_att, attn_weights, attn_scores_out = self.self_attn(
src_att, attn_weights = self.self_attn(
src,
pos_emb=pos_emb,
attn_scores_in=attn_scores_in,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)
@ -358,7 +354,7 @@ class ConformerEncoderLayer(nn.Module):
bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
src = src_orig + delta * self.bypass_scale
return src, attn_scores_out
return src
class ConformerEncoder(nn.Module):
@ -516,7 +512,6 @@ class ConformerEncoder(nn.Module):
output = src
outputs = []
attn_scores = None
rnd_seed = src.numel() + random.randint(0, 1000)
@ -527,10 +522,9 @@ class ConformerEncoder(nn.Module):
for i, mod in enumerate(self.layers):
if i in layers_to_drop:
continue
output, attn_scores = mod(
output = mod(
output,
pos_emb,
attn_scores,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
)
@ -851,8 +845,6 @@ class RelPositionMultiheadAttention(nn.Module):
self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
initial_scale=0.05)
self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
@ -870,7 +862,6 @@ class RelPositionMultiheadAttention(nn.Module):
self,
x: Tensor,
pos_emb: Tensor,
attn_scores_in: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
@ -892,7 +883,6 @@ class RelPositionMultiheadAttention(nn.Module):
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- attn_scores_in: :math:`(N, L, L, num_heads)`
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
@ -905,19 +895,16 @@ class RelPositionMultiheadAttention(nn.Module):
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Returns: (attn_output, attn_weights, attn_scores)
- Returns: (attn_output, attn_weights)
- attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size,
E is the embedding dimension.
- attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads
and S is the sequence length.
- attn_scores: :math:`(N, S, S, H)`, these are the attn weights
before softmax.
"""
x, weights, scores = self.multi_head_attention_forward(
x, weights = self.multi_head_attention_forward(
self.in_balancer(self.in_proj(x)),
pos_emb,
None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
self.embed_dim,
self.num_heads,
self.in_proj.weight,
@ -929,15 +916,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
if attn_scores_in is not None:
attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
attn_scores_out = attn_scores_out + attn_scores_in
else:
# Here, add self.attn_scores_proj_in in order to make sure it has
# a grad.
attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out +
self.attn_scores_proj_in)
return x, weights, attn_scores_out
return x, weights
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
@ -968,7 +947,6 @@ class RelPositionMultiheadAttention(nn.Module):
self,
x: Tensor,
pos_emb: Tensor,
attn_scores_in: Optional[Tensor],
embed_dim: int,
num_heads: int,
in_proj_weight: Tensor,
@ -1023,8 +1001,6 @@ class RelPositionMultiheadAttention(nn.Module):
E is the embedding dimension.
- attn_weights: :math:`(N * H, S, S)` where N is the batch size,
H is the num-heads, S is the sequence length.
- attn_scores: :math:`(N, S, S, H)` where N is the batch size,
S is the sequence length and H is the num-heads.
"""
tgt_len, bsz, _ = x.size()
@ -1137,10 +1113,6 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) # (batch, head, time1, time2)
attn_scores_out = attn_output_weights.permute(0, 2, 3, 1) # (batch, time1, time2, head)
if attn_scores_in is not None:
attn_output_weights = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)
attn_output_weights = attn_output_weights.view(
@ -1187,7 +1159,7 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output, out_proj_weight, out_proj_bias
)
return attn_output, attn_output_weights, attn_scores_out
return attn_output, attn_output_weights
def forward2(