apply penalize_abs_values_gt on attn_output_weights

This commit is contained in:
yaozengwei 2023-01-29 14:46:14 +08:00
parent 24d6565126
commit 3364d9863c

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -14,6 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The model structure is modified from Daniel Povey's Zipformer
# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
import itertools
import logging
import math
@ -605,6 +609,16 @@ class MultiHeadedAttention(nn.Module):
# (batch, head, time1, time2)
attn_output_weights = torch.matmul(q, k)
# This is a harder way of limiting the attention scores to not be too large.
# It incurs a penalty if any of them has an absolute value greater than 50.0.
# this should be outside the normal range of the attention scores. We use
# this mechanism instead of, say, a limit on entropy, because once the entropy
# gets very small gradients through the softmax can become very small, and
# some mechanisms like that become ineffective.
attn_output_weights = penalize_abs_values_gt(
attn_output_weights, limit=25.0, penalty=1.0e-04
)
if mask is not None:
attn_output_weights = attn_output_weights.masked_fill(
mask.unsqueeze(1), float("-inf")