Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
Avoid penalize_abs_values_gt when memory usage high

commit 625e39fd1a
parent 815cc1ba4f
@@ -38,6 +38,7 @@ from scaling import (
     limit_param_value,
     clip_grad,
     convert_num_channels,
+    AbsValuePenalizer,
 )
 from torch import Tensor, nn
@@ -1380,7 +1381,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.pos_dim = pos_dim
         self.dropout = dropout
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         self.name = None  # will be overwritten in training code; for diagnostics.
+        self.score_penalty = AbsValuePenalizer(
+            limit=25.0, penalty=1.0e-04, prob=0.1)

         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads
@@ -1484,23 +1486,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # pos_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len)
         attn_scores = attn_scores + pos_scores

-        if self.training and random.random() < 0.1:
-            # This is a way of limiting the attention scores to not be
-            # too large. It incurs a penalty if any of them has an absolute
-            # value greater than 25.0. This should be outside the normal range
-            # of the attention scores. We use this mechanism instead of, say,
-            # something added to the loss function involving the entropy,
-            # because once the entropy gets very small, gradients through the
-            # softmax can become very small and we'd get zero derivatives. The
-            # choice of 1.0e-04 as the scale on the penalty makes this
-            # mechanism vulnerable to the absolute scale of the loss function,
-            # but we view this as a failsafe to avoid "implausible" parameter
-            # values rather than a regularization method that should be active
-            # under normal circumstances.
-            attn_scores = penalize_abs_values_gt(attn_scores,
-                                                 limit=25.0,
-                                                 penalty=1.0e-04,
-                                                 name=self.name)
+        attn_scores = self.score_penalty(attn_scores)

         # attn_offset includes key-padding mask and attention-mask, plus any weights
         # from the subsampling.
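
For reference, the comment in the removed block describes a penalty that leaves the forward value untouched and only injects a gradient. Below is a minimal self-contained sketch of that mechanism; it is an illustration only, not icefall's actual penalize_abs_values_gt, whose body is not part of this diff. The backward pass adds the gradient of penalty * sum(relu(|x| - limit)), i.e. penalty * sign(x) wherever |x| exceeds the limit:

import torch
from torch import Tensor

class PenalizeAbsValuesGtSketch(torch.autograd.Function):
    # Illustrative sketch: forward returns x unchanged; backward adds
    # penalty * sign(x) at the elements where |x| exceeds `limit`.

    @staticmethod
    def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (x,) = ctx.saved_tensors
        over_limit = x.abs() > ctx.limit  # elements outside [-limit, limit]
        extra_grad = ctx.penalty * x.sign() * over_limit
        return grad_output + extra_grad, None, None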
@@ -1634,7 +1620,8 @@ class MultiheadAttentionWeights(nn.Module):
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.dropout = dropout
         self.name = None  # will be overwritten in training code; for diagnostics.
+        self.score_penalty = AbsValuePenalizer(
+            limit=25.0, penalty=1.0e-04, prob=0.1)


         # the initial_scale is supposed to take over the "scaling" factor of
@@ -1697,23 +1684,7 @@ class MultiheadAttentionWeights(nn.Module):

         attn_scores = torch.matmul(q, k)

-        if self.training and random.random() < 0.1:
-            # This is a way of limiting the attention scores to not be
-            # too large. It incurs a penalty if any of them has an absolute
-            # value greater than 25.0. This should be outside the normal range
-            # of the attention scores. We use this mechanism instead of, say,
-            # something added to the loss function involving the entropy,
-            # because once the entropy gets very small, gradients through the
-            # softmax can become very small and we'd get zero derivatives. The
-            # choice of 1.0e-04 as the scale on the penalty makes this
-            # mechanism vulnerable to the absolute scale of the loss function,
-            # but we view this as a failsafe to avoid "implausible" parameter
-            # values rather than a regularization method that should be active
-            # under normal circumstances.
-            attn_scores = penalize_abs_values_gt(attn_scores,
-                                                 limit=25.0,
-                                                 penalty=1.0e-04,
-                                                 name=self.name)
+        attn_scores = self.score_penalty(attn_scores)

         assert attn_scores.shape == (num_heads, batch_size, query_len, key_len)
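
To see the mechanism end to end, here is a hypothetical run of the PenalizeAbsValuesGtSketch helper defined above (the tensor shape and values are made up):

scores = (torch.randn(2, 3) * 40.0).requires_grad_()
out = PenalizeAbsValuesGtSketch.apply(scores, 25.0, 1.0e-04)
out.sum().backward()
# scores.grad is 1.0 everywhere, plus or minus 1.0e-04 at the elements
# whose absolute value exceeded 25.0.
print(scores.grad)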
@@ -845,6 +845,40 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
     return x


+class AbsValuePenalizer(nn.Module):
+    """
+    This module adds a penalty to the loss function whenever the absolute value
+    of any element of the input tensor exceeds a certain limit.
+    """
+    def __init__(self,
+                 limit: float,
+                 prob: float = 0.1,
+                 penalty: float = 1.0e-04):
+        super().__init__()
+        self.limit = limit
+        self.penalty = penalty
+
+        self.prob = prob
+        self.name = None  # will be set in training loop
+
+        # 20% of the time we will return and do nothing, because memory usage
+        # is too high.
+        self.mem_cutoff = CutoffEstimator(0.2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if (torch.jit.is_scripting() or not x.requires_grad or
+                (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
+                or random.random() > self.prob):
+            return _no_op(x)  # the _no_op op is to make our diagnostics code work.
+
+        x = penalize_abs_values_gt(x,
+                                   limit=self.limit,
+                                   penalty=self.penalty,
+                                   name=self.name)
+        return x
+
+
 def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
     if x.ndim == 2:
         return x.diag()
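
The new module depends on two helpers defined elsewhere in scaling.py and not shown in this diff: _no_op (an identity op that, per the comment, keeps the tensor visible to the diagnostics code) and CutoffEstimator. Judging only from the usage above, CutoffEstimator(0.2) takes a scalar such as torch.cuda.memory_allocated() and returns True when it falls in roughly the top 20% of values observed, so the penalty is skipped exactly when memory pressure is highest. A hedged sketch of an object with that interface (the sliding window and recent-history quantile are assumptions, not icefall's implementation):

from collections import deque

class CutoffEstimatorSketch:
    # Assumed interface only: __call__(value) -> bool, returning True for
    # roughly the top `p` fraction of recently observed values.

    def __init__(self, p: float, window: int = 1000):
        self.p = p                           # target proportion returning True
        self.history = deque(maxlen=window)  # recent observations

    def __call__(self, x: float) -> bool:
        self.history.append(x)
        ranked = sorted(self.history)
        # cutoff is the (1 - p) quantile of recent values
        cutoff = ranked[int((1.0 - self.p) * (len(ranked) - 1))]
        return x > cutoff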