mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
apply penalize_abs_values_gt on attn_output_weights
This commit is contained in:
parent
24d6565126
commit
3364d9863c
@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
|
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -14,6 +14,10 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
# The model structure is modified from Daniel Povey's Zipformer
|
||||||
|
# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@ -605,6 +609,16 @@ class MultiHeadedAttention(nn.Module):
|
|||||||
# (batch, head, time1, time2)
|
# (batch, head, time1, time2)
|
||||||
attn_output_weights = torch.matmul(q, k)
|
attn_output_weights = torch.matmul(q, k)
|
||||||
|
|
||||||
|
# This is a harder way of limiting the attention scores to not be too large.
|
||||||
|
# It incurs a penalty if any of them has an absolute value greater than 50.0.
|
||||||
|
# this should be outside the normal range of the attention scores. We use
|
||||||
|
# this mechanism instead of, say, a limit on entropy, because once the entropy
|
||||||
|
# gets very small gradients through the softmax can become very small, and
|
||||||
|
# some mechanisms like that become ineffective.
|
||||||
|
attn_output_weights = penalize_abs_values_gt(
|
||||||
|
attn_output_weights, limit=25.0, penalty=1.0e-04
|
||||||
|
)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
attn_output_weights = attn_output_weights.masked_fill(
|
attn_output_weights = attn_output_weights.masked_fill(
|
||||||
mask.unsqueeze(1), float("-inf")
|
mask.unsqueeze(1), float("-inf")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user