From 65d7192dca03ba21bff4270add3891c9730491a7 Mon Sep 17 00:00:00 2001
From: Zengwei Yao
Date: Mon, 19 Dec 2022 20:10:39 +0800
Subject: [PATCH] Fix zipformer attn_output_weights (#774)

* fix attn_output_weights

* remove in-place op
---
 .../pruned_transducer_stateless7/zipformer.py | 34 +++++++++++++++++--
 1 file changed, 32 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 71f12e44a..ad3b88df0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1291,9 +1291,11 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
-                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+                attn_output_weights = attn_output_weights.masked_fill(
+                    attn_mask, float("-inf")
+                )
             else:
-                attn_output_weights += attn_mask
+                attn_output_weights = attn_output_weights + attn_mask
 
         if key_padding_mask is not None:
             attn_output_weights = attn_output_weights.view(
@@ -1313,6 +1315,34 @@ class RelPositionMultiheadAttention(nn.Module):
         # only storing the half-precision output for backprop purposes.
         attn_output_weights = softmax(attn_output_weights, dim=-1)
 
+        # If we are using a chunk-wise attention mask with a limited
+        # num_left_chunks, some queries may attend only to padding positions,
+        # which are also masked out by `key_padding_mask`. In that case the
+        # corresponding rows of `attn_output_weights` are entirely `-inf`
+        # (i.e. `nan` after softmax). So we fill `0.0` at the masked
+        # positions to avoid an invalid loss value below.
+        if (
+            attn_mask is not None
+            and attn_mask.dtype == torch.bool
+            and key_padding_mask is not None
+        ):
+            if attn_mask.size(0) != 1:
+                attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
+
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, seq_len, seq_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, seq_len, seq_len
+            )
+
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
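
Note on the patch: the first hunk replaces the in-place masked_fill_ and += with out-of-place equivalents ("remove in-place op" in the commit message). The second hunk addresses a nan issue: when a query position can only attend to keys that are masked out (a chunk-wise attn_mask with few left chunks combined with key_padding_mask), its whole row of attention scores is -inf before the softmax, and softmax over an all -inf row yields nan, which then poisons the loss. The standalone sketch below is illustrative only (toy tensor shapes, a single mask standing in for the combined attn_mask/key_padding_mask used in the patch); it reproduces the nan and shows the post-softmax zero-fill that the patch applies.

    import torch

    # Toy shapes: (batch * heads, tgt_len, src_len); names are illustrative only.
    scores = torch.randn(1, 3, 4)
    mask = torch.zeros(1, 3, 4, dtype=torch.bool)
    mask[0, 2, :] = True  # the last query position sees only masked (padding) keys

    masked_scores = scores.masked_fill(mask, float("-inf"))
    weights = masked_scores.softmax(dim=-1)
    print(weights[0, 2])  # tensor([nan, nan, nan, nan]): softmax of an all -inf row

    # The patch's remedy: fill the masked positions with 0.0 *after* the softmax,
    # so fully-masked rows contribute zeros instead of nan downstream.
    weights = weights.masked_fill(mask, 0.0)
    print(weights[0, 2])  # tensor([0., 0., 0., 0.])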