From 625e39fd1a128bc282be96cfe54c207cdeac2f42 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 28 May 2023 20:40:47 +0800
Subject: [PATCH] Avoid penalize_abs_values_gt when memory usage high

---
 egs/libriheavy/LM/zipformer1/subformer.py     | 43 +++----------
 .../pruned_transducer_stateless7/scaling.py   | 34 +++++++++++++++
 2 files changed, 41 insertions(+), 36 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index ae589813d..6813c0aca 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -38,6 +38,7 @@ from scaling import (
     limit_param_value,
     clip_grad,
     convert_num_channels,
+    AbsValuePenalizer,
 )
 from torch import Tensor, nn
 
@@ -1380,7 +1381,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.pos_dim = pos_dim
         self.dropout = dropout
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
-        self.name = None  # will be overwritten in training code; for diagnostics.
+        self.score_penalty = AbsValuePenalizer(
+            limit=25.0, penalty=1.0e-04, prob=0.1)
 
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads
@@ -1484,23 +1486,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # pos_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len)
         attn_scores = attn_scores + pos_scores
 
-        if self.training and random.random() < 0.1:
-            # This is away of limiting the attention scores to not be
-            # too large.  It incurs a penalty if any of them has an absolute
-            # value greater than 25.0.  this should be outside the normal range
-            # of the attention scores.  We use this mechanism instead of, say,
-            # something added to the loss function involving the entropy,
-            # because once the entropy gets very small gradients through the
-            # softmax can become very small, and we'd get zero derivatives.  The
-            # choices of 1.0e-04 as the scale on the penalty makes this
-            # mechanism vulnerable to the absolute scale of the loss function,
-            # but we view this as a failsafe to avoid "implausible" parameter
-            # values rather than a regularization method that should be active
-            # under normal circumstances.
-            attn_scores = penalize_abs_values_gt(attn_scores,
-                                                 limit=25.0,
-                                                 penalty=1.0e-04,
-                                                 name=self.name)
+        attn_scores = self.score_penalty(attn_scores)
 
         # attn_offset includes key-padding mask and attention-mask, plus any weights
         # from the subsampling.
@@ -1634,7 +1620,8 @@ class MultiheadAttentionWeights(nn.Module):
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.dropout = dropout
-        self.name = None  # will be overwritten in training code; for diagnostics.
+        self.score_penalty = AbsValuePenalizer(
+            limit=25.0, penalty=1.0e-04, prob=0.1)
 
 
         # the initial_scale is supposed to take over the "scaling" factor of
@@ -1697,23 +1684,7 @@ class MultiheadAttentionWeights(nn.Module):
 
         attn_scores = torch.matmul(q, k)
 
-        if self.training and random.random() < 0.1:
-            # This is a way of limiting the attention scores to not be
-            # too large.  It incurs a penalty if any of them has an absolute
-            # value greater than 25.0.  this should be outside the normal range
-            # of the attention scores.  We use this mechanism instead of, say,
-            # something added to the loss function involving the entropy,
-            # because once the entropy gets very small gradients through the
-            # softmax can become very small, and we'd get zero derivatives.  The
-            # choices of 1.0e-04 as the scale on the penalty makes this
-            # mechanism vulnerable to the absolute scale of the loss function,
-            # but we view this as a failsafe to avoid "implausible" parameter
-            # values rather than a regularization method that should be active
-            # under normal circumstances.
-            attn_scores = penalize_abs_values_gt(attn_scores,
-                                                 limit=25.0,
-                                                 penalty=1.0e-04,
-                                                 name=self.name)
+        attn_scores = self.score_penalty(attn_scores)
 
         assert attn_scores.shape == (num_heads, batch_size, query_len, key_len)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 828dd0fed..210bc1e2f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -845,6 +845,40 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
     return x
 
 
+class AbsValuePenalizer(nn.Module):
+    """
+    This module adds a penalty to the loss function whenever the absolute value
+    of any element of the input tensor exceeds a certain limit.
+    """
+
+    def __init__(self,
+                 limit: float,
+                 prob: float = 0.1,
+                 penalty: float = 1.0e-04):
+        super().__init__()
+        self.limit = limit
+        self.penalty = penalty
+
+        self.prob = prob
+        self.name = None  # will be set in training loop
+
+        # About 20% of the time we will return without applying the penalty,
+        # because memory usage is too high.
+        self.mem_cutoff = CutoffEstimator(0.2)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if (torch.jit.is_scripting() or not x.requires_grad or
+            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
+                or random.random() > self.prob):
+            return _no_op(x)  # _no_op is to keep our diagnostics code working.
+
+        x = penalize_abs_values_gt(x,
+                                   limit=self.limit,
+                                   penalty=self.penalty,
+                                   name=self.name)
+        return x
+
+
 def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
     if x.ndim == 2:
         return x.diag()
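
Reviewer note: penalize_abs_values_gt() is defined earlier in scaling.py and
does not appear in this diff.  For context, below is a minimal sketch of the
mechanism that AbsValuePenalizer wraps, under the assumption (taken from the
comments above) that the forward pass is the identity and the backward pass
adds an extra gradient of penalty * sign(x) to elements with |x| > limit.
The names _PenalizeAbsGt and penalize_abs_values_gt_sketch are hypothetical,
and the simplified signature drops the `name` diagnostics argument of the
real function.

import torch
from torch import Tensor

class _PenalizeAbsGt(torch.autograd.Function):
    # Hypothetical simplification; NOT the actual icefall implementation.
    @staticmethod
    def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x.clone()  # identity: the penalty only affects gradients

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (x,) = ctx.saved_tensors
        # Elements with |x| > limit pick up an extra gradient term of
        # penalty * sign(x), nudging them back toward [-limit, limit];
        # in-range elements are left untouched.
        extra_grad = ctx.penalty * x.sign() * (x.abs() > ctx.limit)
        return grad_output + extra_grad, None, None

def penalize_abs_values_gt_sketch(x: Tensor, limit: float,
                                  penalty: float) -> Tensor:
    return _PenalizeAbsGt.apply(x, limit, penalty)

# Quick check: forward values are unchanged, but out-of-range elements
# receive a small extra gradient on top of the incoming ones.
x = torch.tensor([0.5, 30.0, -40.0], requires_grad=True)
y = penalize_abs_values_gt_sketch(x, limit=25.0, penalty=1.0e-04)
y.sum().backward()
print(x.grad)  # tensor([1.0000, 1.0001, 0.9999])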
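
CutoffEstimator is likewise defined elsewhere in scaling.py.  Judging from the
comment in __init__ above, CutoffEstimator(0.2) adapts a threshold so that
roughly 20% of the values passed to it exceed that threshold, returning True
in exactly those cases; forward() then skips the penalty on the
highest-memory calls.  A hedged sketch of such an estimator follows (the
class name, step size, and update rule are guesses, not the actual icefall
code):

class CutoffEstimatorSketch:
    """Tracks a cutoff such that roughly `prob` of the values seen so far
    exceed it, via standard stochastic quantile tracking (a guess at the
    real CutoffEstimator's behavior)."""
    def __init__(self, prob: float, step: float = 1.0e7):
        self.prob = prob    # target fraction of values above the cutoff
        self.step = step    # update size; chosen here with byte counts in mind
        self.cutoff = 0.0

    def __call__(self, value: float) -> bool:
        above = value > self.cutoff
        # Drift up by (1 - prob) * step when exceeded, down by prob * step
        # otherwise; the expected movement is zero when P(above) == prob.
        self.cutoff += self.step * ((1.0 - self.prob) if above else -self.prob)
        return above

# Usage mirroring forward() above: feed it a stream of memory readings and
# skip the penalty whenever the current reading is above the tracked cutoff.
mem_cutoff = CutoffEstimatorSketch(0.2)
for reading in (1.0e9, 3.0e9, 2.0e9, 8.0e9, 2.5e9, 9.0e9):
    if mem_cutoff(reading):
        pass  # high-memory step: skip penalize_abs_values_gt here

Once the cutoff has warmed up, the penalty is avoided on roughly the top 20%
of memory-allocated steps, which is the point of this patch.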