Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Fix zipformer attn_output_weights (#774)
* fix attn_output_weights
* remove in-place op
This commit is contained in:
parent fbc1d3b194
commit 65d7192dca
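The commit message says "remove in-place op". As background (an illustrative sketch, not part of the commit; all names below are hypothetical), the usual reason such ops must go is that PyTorch autograd saves tensors for the backward pass, and mutating a saved tensor in place raises a RuntimeError at backward time:

    import torch

    mask = torch.tensor([[True, False, False],
                         [False, True, False]])

    w = torch.randn(2, 3, requires_grad=True)
    scores = w.exp()                          # exp() saves its output for backward
    scores.masked_fill_(mask, float("-inf"))  # in-place edit bumps the tensor's version
    try:
        scores.sum().backward()               # autograd detects the mutation and fails
    except RuntimeError as err:
        print(err)  # "... has been modified by an inplace operation ..."

    w = torch.randn(2, 3, requires_grad=True)
    scores = w.exp()
    scores = scores.masked_fill(mask, float("-inf"))  # out-of-place, as in the fix below
    scores.sum().backward()                   # succeeds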
@@ -1291,9 +1291,11 @@ class RelPositionMultiheadAttention(nn.Module):
         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
-                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+                attn_output_weights = attn_output_weights.masked_fill(
+                    attn_mask, float("-inf")
+                )
             else:
-                attn_output_weights += attn_mask
+                attn_output_weights = attn_output_weights + attn_mask
 
         if key_padding_mask is not None:
             attn_output_weights = attn_output_weights.view(
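A quick aside on the second change in this hunk (a sketch under my own assumptions, not from the commit): `x += y` writes into the tensor's existing buffer, while `x = x + y` allocates a fresh tensor, which is what makes it safe for autograd. The difference is visible via `data_ptr()`:

    import torch

    x = torch.zeros(3)
    ptr = x.data_ptr()
    x += 1                       # in-place: same underlying storage
    assert x.data_ptr() == ptr
    x = x + 1                    # out-of-place: new storage
    assert x.data_ptr() != ptr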
@@ -1313,6 +1315,34 @@ class RelPositionMultiheadAttention(nn.Module):
         # only storing the half-precision output for backprop purposes.
         attn_output_weights = softmax(attn_output_weights, dim=-1)
 
+        # If we are using a chunk-wise attention mask and setting a limited
+        # num_left_chunks, the attention may only see the padding values, which
+        # will also be masked out by `key_padding_mask`. Under these
+        # circumstances, the whole column of `attn_output_weights` will be
+        # `-inf` (i.e. `nan` after softmax). So we fill `0.0` at the masked
+        # positions to avoid an invalid loss value below.
+        if (
+            attn_mask is not None
+            and attn_mask.dtype == torch.bool
+            and key_padding_mask is not None
+        ):
+            if attn_mask.size(0) != 1:
+                attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
+
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, seq_len, seq_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, seq_len, seq_len
+            )
 
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
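A minimal sketch (illustrative only, not from the commit) of the `nan` issue the new block guards against: when every position a query can attend to is masked to `-inf`, softmax returns `nan` for that entire slice, and filling `0.0` at the masked positions afterwards restores valid values:

    import torch

    scores = torch.full((1, 4), float("-inf"))  # a fully masked slice
    probs = torch.softmax(scores, dim=-1)
    print(probs)                                # tensor([[nan, nan, nan, nan]])

    mask = torch.isinf(scores)                  # stand-in for `combined_mask` above
    probs = probs.masked_fill(mask, 0.0)
    print(probs)                                # tensor([[0., 0., 0., 0.]])

In the commit itself, `combined_mask` is built by broadcasting: `attn_mask`, viewed as (bsz, num_heads, seq_len, seq_len) or unsqueezed from (1, tgt_len, src_len), is OR-ed with `key_padding_mask.unsqueeze(1).unsqueeze(2)`, which has shape (bsz, 1, 1, seq_len).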