Avoid penalize_abs_values_gt when memory usage high

Daniel Povey 2023-05-28 20:40:47 +08:00
parent 815cc1ba4f
commit 625e39fd1a
2 changed files with 41 additions and 36 deletions


@@ -38,6 +38,7 @@ from scaling import (
limit_param_value,
clip_grad,
convert_num_channels,
AbsValuePenalizer,
)
from torch import Tensor, nn
@@ -1380,7 +1381,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self.pos_dim = pos_dim
self.dropout = dropout
self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
self.name = None # will be overwritten in training code; for diagnostics.
self.score_penalty = AbsValuePenalizer(
limit=25.0, penalty=1.0e-04, prob=0.1)
key_head_dim = query_head_dim
in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads
@@ -1484,23 +1486,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# pos_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len)
attn_scores = attn_scores + pos_scores
if self.training and random.random() < 0.1:
# This is a way of limiting the attention scores to not be
# too large. It incurs a penalty if any of them has an absolute
# value greater than 25.0. This should be outside the normal range
# of the attention scores. We use this mechanism instead of, say,
# something added to the loss function involving the entropy,
# because once the entropy gets very small gradients through the
# softmax can become very small, and we'd get zero derivatives. The
# choice of 1.0e-04 as the scale on the penalty makes this
# mechanism vulnerable to the absolute scale of the loss function,
# but we view this as a failsafe to avoid "implausible" parameter
# values rather than a regularization method that should be active
# under normal circumstances.
attn_scores = penalize_abs_values_gt(attn_scores,
limit=25.0,
penalty=1.0e-04,
name=self.name)
attn_scores = self.score_penalty(attn_scores)
# attn_offset includes key-padding mask and attention-mask, plus any weights
# from the subsampling.
@@ -1634,7 +1620,8 @@ class MultiheadAttentionWeights(nn.Module):
self.num_heads = num_heads
self.head_dim = head_dim
self.dropout = dropout
self.name = None # will be overwritten in training code; for diagnostics.
self.score_penalty = AbsValuePenalizer(
limit=25.0, penalty=1.0e-04, prob=0.1)
# the initial_scale is supposed to take over the "scaling" factor of
@@ -1697,23 +1684,7 @@ class MultiheadAttentionWeights(nn.Module):
attn_scores = torch.matmul(q, k)
if self.training and random.random() < 0.1:
# This is a way of limiting the attention scores to not be
# too large. It incurs a penalty if any of them has an absolute
# value greater than 25.0. This should be outside the normal range
# of the attention scores. We use this mechanism instead of, say,
# something added to the loss function involving the entropy,
# because once the entropy gets very small gradients through the
# softmax can become very small, and we'd get zero derivatives. The
# choice of 1.0e-04 as the scale on the penalty makes this
# mechanism vulnerable to the absolute scale of the loss function,
# but we view this as a failsafe to avoid "implausible" parameter
# values rather than a regularization method that should be active
# under normal circumstances.
attn_scores = penalize_abs_values_gt(attn_scores,
limit=25.0,
penalty=1.0e-04,
name=self.name)
attn_scores = self.score_penalty(attn_scores)
assert attn_scores.shape == (num_heads, batch_size, query_len, key_len)
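
The comments removed above describe the mechanism behind penalize_abs_values_gt: it penalizes any attention score whose absolute value exceeds 25.0, and since the returned tensor is fed straight into the softmax, the forward value is left essentially unchanged and the penalty acts through the gradient. As a minimal sketch of that idea only (this is not the actual scaling.py implementation; the names _AbsValuePenaltyFn and abs_value_penalty are made up for illustration):

import torch
from torch import Tensor

class _AbsValuePenaltyFn(torch.autograd.Function):
    """Identity in the forward pass; in the backward pass, adds an extra
    gradient of penalty * sign(x) for elements with |x| > limit, i.e. the
    gradient of penalty * relu(|x| - limit), so out-of-range elements are
    nudged back toward the limit without changing the forward value."""

    @staticmethod
    def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x.clone()  # same value as x; clone so the custom backward is attached cleanly

    @staticmethod
    def backward(ctx, grad_out: Tensor):
        (x,) = ctx.saved_tensors
        extra = ctx.penalty * x.sign() * (x.abs() > ctx.limit)
        return grad_out + extra, None, None

def abs_value_penalty(x: Tensor, limit: float = 25.0,
                      penalty: float = 1.0e-04) -> Tensor:
    return _AbsValuePenaltyFn.apply(x, limit, penalty)

The new AbsValuePenalizer in scaling.py (the second file in this diff) wraps the real penalize_abs_values_gt call and, as of this commit, decides on each forward call whether to apply it at all.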


@@ -845,6 +845,40 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
return x
class AbsValuePenalizer(nn.Module):
"""
This module adds a penalty to the loss function whenever the absolute value of
any element of the input tensor exceeds a certain limit.
"""
def __init__(self,
limit: float,
prob: float = 0.1,
penalty: float = 1.0e-04):
super().__init__()
self.limit = limit
self.penalty = penalty
self.prob = prob
self.name = None # will be set in training loop
# 20% of the time we will return and do nothing because memory usage is
# too high.
self.mem_cutoff = CutoffEstimator(0.2)
def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad or
(x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
or random.random() > self.prob):
return _no_op(x) # the _no_op op is to make our diagnostics code work.
x = penalize_abs_values_gt(x,
limit=self.limit,
penalty=self.penalty,
name=self.name)
return x
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
if x.ndim == 2:
return x.diag()
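
CutoffEstimator is referenced by the new class but not shown in this diff. The sketch below is a hedged illustration only: _SimpleCutoff is a made-up stand-in for what CutoffEstimator(0.2) is used for here (deciding whether the current torch.cuda.memory_allocated() reading is among roughly the highest 20% of readings seen so far), and the usage at the end assumes the repo's scaling.py is importable:

import torch
from scaling import AbsValuePenalizer  # added to scaling.py in this commit

class _SimpleCutoff:
    """Hypothetical stand-in for CutoffEstimator(p): answers "is this value
    among roughly the top p of values seen so far?" by keeping a bounded
    history and comparing against an empirical quantile.  The real class in
    scaling.py may be implemented quite differently; only the intent matches."""
    def __init__(self, p: float):
        self.p = p
        self.history = []

    def __call__(self, value: float) -> bool:
        self.history.append(float(value))
        self.history = self.history[-1000:]                 # keep a bounded history
        k = max(1, int(round(len(self.history) * self.p)))
        cutoff = sorted(self.history, reverse=True)[k - 1]  # k-th largest value seen
        return value >= cutoff

# This is the role mem_cutoff plays inside AbsValuePenalizer.forward():
mem_cutoff = _SimpleCutoff(0.2)
if torch.cuda.is_available():
    usage_is_high = mem_cutoff(torch.cuda.memory_allocated())  # skip the penalty when True

# Typical use, mirroring the attention modules in the first file: construct the
# penalizer once, then pass the attention scores through it.  The penalty fires
# on only about `prob` of the calls where the input requires grad, and never
# when memory usage is judged to be high, which is the point of this commit.
penalizer = AbsValuePenalizer(limit=25.0, penalty=1.0e-04, prob=0.1)
attn_scores = torch.randn(4, 2, 50, 50, requires_grad=True)
attn_scores = penalizer(attn_scores)  # identity in the forward pass

Tying the skip decision to observed memory usage presumably helps because the penalty's extra autograd bookkeeping itself costs memory, which is what the commit title refers to.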