mirror of https://github.com/k2-fsa/icefall.git
Fix zipformer attn_output_weights (#774)

* fix attn_output_weights
* remove in-place op

parent fbc1d3b194
commit 65d7192dca
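The "remove in-place op" half of this change replaces the in-place `masked_fill_` and `+=` calls with their out-of-place counterparts, as shown in the first hunk of the diff below; an out-of-place call returns a new tensor instead of mutating one that other parts of the graph may still reference, which is presumably the motivation. A minimal sketch of the two styles, independent of icefall (tensor names and shapes here are made up for illustration):

import torch

# Toy attention scores and a boolean mask; shapes are illustrative only.
attn_output_weights = torch.randn(2, 4, 4)   # (batch * heads, tgt_len, src_len)
attn_mask = torch.zeros(4, 4, dtype=torch.bool)
attn_mask[:, 2:] = True                      # pretend the last two keys are disallowed

# Old style: in-place, mutates the tensor it is called on.
w_inplace = attn_output_weights.clone()
w_inplace.masked_fill_(attn_mask, float("-inf"))

# New style: out-of-place, returns a fresh tensor and rebinds the name.
w_out_of_place = attn_output_weights.masked_fill(attn_mask, float("-inf"))

# The resulting values are identical; only the aliasing behaviour differs.
assert torch.equal(w_inplace, w_out_of_place)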
@@ -1291,9 +1291,11 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
-                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+                attn_output_weights = attn_output_weights.masked_fill(
+                    attn_mask, float("-inf")
+                )
             else:
-                attn_output_weights += attn_mask
+                attn_output_weights = attn_output_weights + attn_mask
 
         if key_padding_mask is not None:
             attn_output_weights = attn_output_weights.view(
@@ -1313,6 +1315,34 @@ class RelPositionMultiheadAttention(nn.Module):
         # only storing the half-precision output for backprop purposes.
         attn_output_weights = softmax(attn_output_weights, dim=-1)
 
+        # If we are using chunk-wise attention mask and setting a limited
+        # num_left_chunks, the attention may only see the padding values which
+        # will also be masked out by `key_padding_mask`. At this circumstances,
+        # the whole column of `attn_output_weights` will be `-inf`
+        # (i.e. be `nan` after softmax). So we fill `0.0` at the masking
+        # positions to avoid invalid loss value below.
+        if (
+            attn_mask is not None
+            and attn_mask.dtype == torch.bool
+            and key_padding_mask is not None
+        ):
+            if attn_mask.size(0) != 1:
+                attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
+
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, seq_len, seq_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, seq_len, seq_len
+            )
+
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
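The comment added in the second hunk describes the failure mode being fixed: with a chunk-wise attention mask and a small num_left_chunks, a query frame can end up able to attend only to padded frames, so once `key_padding_mask` is applied its entire score row is `-inf`, the softmax of that row is `nan`, and the `nan` would propagate into the loss. A minimal sketch of the problem and of the post-softmax zero-fill used in the patch (masks, shapes, and values below are toy examples, not the tensors icefall actually builds):

import torch

# Toy scores for one head: 3 query frames x 4 key frames.
scores = torch.randn(3, 4)

# Chunk-wise mask (True = not attendable) that hides keys 0-1 for the last query,
# combined with a padding mask that hides keys 2-3 for every query.
attn_mask = torch.tensor([[False, False, True,  True],
                          [False, False, True,  True],
                          [True,  True,  False, False]])
key_padding_mask = torch.tensor([False, False, True, True])
combined_mask = attn_mask | key_padding_mask      # row 2 ends up fully masked

scores = scores.masked_fill(combined_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)
print(weights[2])   # tensor([nan, nan, nan, nan]) -- softmax over an all -inf row

# The fix: zero out the masked positions after softmax so nan never reaches the loss.
weights = weights.masked_fill(combined_mask, 0.0)
print(weights[2])   # tensor([0., 0., 0., 0.])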