From 3364d9863c7816da59b9df30f58784617ab23b84 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Sun, 29 Jan 2023 14:46:14 +0800
Subject: [PATCH] apply penalize_abs_values_gt on attn_output_weights

---
 .../ASR/zipformer_ctc_attn/attention_decoder.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py b/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
index 8bb96c002..7b232d0fe 100644
--- a/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
+++ b/egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
+# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -14,6 +14,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# The model structure is modified from Daniel Povey's Zipformer
+# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+
 import itertools
 import logging
 import math
@@ -605,6 +609,16 @@ class MultiHeadedAttention(nn.Module):
         # (batch, head, time1, time2)
         attn_output_weights = torch.matmul(q, k)
 
+        # This is a harder way of limiting the attention scores to not be too large.
+        # It incurs a penalty if any of them has an absolute value greater than 25.0,
+        # which should be outside the normal range of the attention scores.  We use
+        # this mechanism instead of, say, a limit on entropy, because once the entropy
+        # gets very small, gradients through the softmax can become very small, and
+        # mechanisms like that become ineffective.
+        attn_output_weights = penalize_abs_values_gt(
+            attn_output_weights, limit=25.0, penalty=1.0e-04
+        )
+
         if mask is not None:
             attn_output_weights = attn_output_weights.masked_fill(
                 mask.unsqueeze(1), float("-inf")
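
Note for reviewers unfamiliar with penalize_abs_values_gt (it comes from the
recipe's scaling.py): it returns its input unchanged in the forward pass and
only injects an extra gradient in the backward pass, behaving as if
penalty * (x.abs() - limit).relu().sum() had been added to the loss. Below is
a minimal sketch of that idea, not the actual scaling.py implementation (which
is written differently for memory efficiency); the name _AbsValuePenalty and
the demo values are invented for this illustration.

import torch
from torch import Tensor


class _AbsValuePenalty(torch.autograd.Function):
    # Hypothetical sketch, not icefall's implementation.  Identity in the
    # forward pass; in the backward pass it adds the gradient of
    # penalty * (x.abs() - limit).relu().sum(), i.e. penalty * sign(x)
    # wherever |x| > limit, to the incoming gradient.
    @staticmethod
    def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (x,) = ctx.saved_tensors
        # Gradient of penalty * relu(|x| - limit) with respect to x.
        extra_grad = ctx.penalty * x.sign() * (x.abs() > ctx.limit)
        # No gradients for the non-tensor args `limit` and `penalty`.
        return grad_output + extra_grad, None, None


def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
    return _AbsValuePenalty.apply(x, limit, penalty)


if __name__ == "__main__":
    scores = 30.0 * torch.randn(2, 4, 8, 8)  # stand-in for attention scores
    scores.requires_grad_(True)
    out = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04)
    assert torch.equal(out, scores)  # forward pass is the identity
    out.sum().backward()
    # Entries within [-25, 25] get gradient 1.0 from sum(); entries outside
    # additionally get +/- 1e-4 from the penalty.
    print(scores.grad.unique())

Because the forward output is unchanged, inference is unaffected and the
caller does not need to add any term to the training loss; the penalty acts
purely through the gradient, which is why the returned tensor must actually
be used downstream for it to have any effect.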