Fix zipformer attn_output_weights (#774)

* fix attn_output_weights

* remove in-place op
This commit is contained in:
Zengwei Yao 2022-12-19 20:10:39 +08:00 committed by GitHub
parent fbc1d3b194
commit 65d7192dca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1291,9 +1291,11 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
attn_output_weights = attn_output_weights.masked_fill(
attn_mask, float("-inf")
)
else:
attn_output_weights += attn_mask
attn_output_weights = attn_output_weights + attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
@ -1313,6 +1315,34 @@ class RelPositionMultiheadAttention(nn.Module):
# only storing the half-precision output for backprop purposes.
attn_output_weights = softmax(attn_output_weights, dim=-1)
# If we are using chunk-wise attention mask and setting a limited
# num_left_chunks, the attention may only see the padding values which
# will also be masked out by `key_padding_mask`. At this circumstances,
# the whole column of `attn_output_weights` will be `-inf`
# (i.e. be `nan` after softmax). So we fill `0.0` at the masking
# positions to avoid invalid loss value below.
if (
attn_mask is not None
and attn_mask.dtype == torch.bool
and key_padding_mask is not None
):
if attn_mask.size(0) != 1:
attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len)
combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
else:
# attn_mask.shape == (1, tgt_len, src_len)
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, seq_len, seq_len
)
attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, seq_len, seq_len
)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)