Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
apply penalize_abs_values_gt on attn_output_weights
parent 24d6565126
commit 3364d9863c
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
+# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -14,6 +14,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# The model structure is modified from Daniel Povey's Zipformer
+# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+
 import itertools
 import logging
 import math
@@ -605,6 +609,16 @@ class MultiHeadedAttention(nn.Module):
         # (batch, head, time1, time2)
         attn_output_weights = torch.matmul(q, k)
+
+        # This is a harder way of limiting the attention scores to not be too large.
+        # It incurs a penalty if any of them has an absolute value greater than 25.0;
+        # this should be outside the normal range of the attention scores. We use
+        # this mechanism instead of, say, a limit on entropy, because once the entropy
+        # gets very small, gradients through the softmax can become very small, and
+        # such mechanisms become ineffective.
+        attn_output_weights = penalize_abs_values_gt(
+            attn_output_weights, limit=25.0, penalty=1.0e-04
+        )

         if mask is not None:
             attn_output_weights = attn_output_weights.masked_fill(
                 mask.unsqueeze(1), float("-inf")
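For context, penalize_abs_values_gt (from the recipe's scaling.py) acts as an identity in the forward pass; the penalty appears only in the backward pass, as an extra gradient on entries whose absolute value exceeds the limit. The sketch below illustrates that idea with a plain custom-autograd function; the helper name _PenalizeAbsGt and the demo values are illustrative, and icefall's actual implementation may differ in details (e.g. memory efficiency).

import torch
from torch import Tensor


class _PenalizeAbsGt(torch.autograd.Function):
    """Identity in forward; in backward, adds the gradient of
    penalty * relu(|x| - limit), i.e. a small push on entries whose
    absolute value exceeds `limit`. (Illustrative, not icefall's code.)"""

    @staticmethod
    def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
        # Cache only the signed over-limit mask that backward needs:
        # +1 where x > limit, -1 where x < -limit, 0 elsewhere.
        ctx.save_for_backward((x.abs() > limit).to(x.dtype) * x.sign())
        ctx.penalty = penalty
        # Return x unchanged; view_as avoids returning the input tensor as-is.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (signed_mask,) = ctx.saved_tensors
        # d/dx [penalty * relu(|x| - limit)] = penalty * sign(x) where
        # |x| > limit, and 0 elsewhere; add it to the incoming gradient.
        return grad_output + ctx.penalty * signed_mask, None, None


def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
    return _PenalizeAbsGt.apply(x, limit, penalty)


if __name__ == "__main__":
    # Scores in the normal range get no extra gradient; outliers do.
    scores = torch.tensor([1.0, 30.0, -40.0], requires_grad=True)
    penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04).sum().backward()
    print(scores.grad)  # tensor([1.0000, 1.0001, 0.9999])

Because the penalty is tiny (1.0e-04 here), in-range scores train exactly as before; unlike a hard clamp, which zeroes the gradient of saturated entries, this only nudges outliers back toward the limit while keeping them fully differentiable.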