Add hard limit of attention weights to +- 50

Daniel Povey 2022-10-20 14:27:55 +08:00
parent c3c655d0bd
commit 4565d43d5c
2 changed files with 40 additions and 3 deletions

@@ -34,7 +34,8 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
-    random_clamp
+    random_clamp,
+    with_loss,
 )
 from torch import Tensor, nn
@@ -1110,16 +1111,37 @@ class RelPositionMultiheadAttention(nn.Module):
                                        storage_offset=pos_weights.stride(3) * (seq_len - 1))
 
+        # caution: they are really scores at this point.
         attn_output_weights = torch.matmul(q, k) + pos_weights
+
+        # The following is a soft way of encouraging the attention scores to not be too large;
+        # in training time, once they get outside a certain range (currently -5.0..5.0), we
+        # randomly either leave them as-is or truncate them to that range.
         if attn_weights_max is not None:
             attn_output_weights = random_clamp(attn_output_weights,
                                                min=-attn_weights_max,
                                                max=attn_weights_max,
                                                prob=0.5)
 
-        # attn_output_weights: (batch, head, time1, time2)
+        if training and random.random() < 0.1:
+            # This is a harder way of limiting the attention scores to not be too large.
+            # It incurs a penalty if any of them has an absolute value greater than 50.0.
+            # This should be outside the normal range of the attention scores.  We use
+            # this mechanism instead of, say, a limit on entropy, because once the entropy
+            # gets very small, gradients through the softmax can become very small, and
+            # some mechanisms like that become ineffective.
+            attn_weights_limit = 50.0
+            # caution: this penalty will be affected by grad-scaling in amp.
+            # It's OK; this is just an emergency brake, and under normal
+            # conditions it shouldn't be active.
+            attn_weights_penalty = 1.0e-04
+            aux_loss = attn_weights_penalty * (attn_output_weights.abs() -
+                                               attn_weights_limit).relu()
+            attn_output_weights = with_loss(attn_output_weights,
+                                            aux_loss)
+
+        # attn_output_weights: (batch, head, time1, time2)
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, seq_len, seq_len
        )
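
For intuition, here is a minimal standalone sketch (not part of the commit; the scores tensor is a made-up example) of how the hinge-style penalty behaves: it is exactly zero while |score| <= 50, and beyond that it grows linearly, so only out-of-range scores receive any extra gradient.

import torch

# Made-up attention scores; 60.0 and -75.0 are outside the +-50 range.
scores = torch.tensor([10.0, 49.0, 60.0, -75.0], requires_grad=True)
limit = 50.0
penalty_scale = 1.0e-04

# Hinge penalty: zero for |score| <= limit, linear beyond it.
aux_loss = penalty_scale * (scores.abs() - limit).relu()
aux_loss.sum().backward()

print(aux_loss)     # approx. [0, 0, 1e-3, 2.5e-3]
print(scores.grad)  # approx. [0, 0, 1e-4, -1e-4]; only out-of-range scores get gradient

Because the penalty scale is tiny (1e-4) and the threshold is far outside the range enforced by random_clamp, this acts purely as an emergency brake rather than a regularizer that shapes normal training.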

@@ -608,6 +608,22 @@ class Whiten(nn.Module):
                                  self.whitening_limit,
                                  self.grad_scale)
 
+
+class WithLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, y: Tensor):
+        ctx.y_shape = y.shape
+        return x
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor):
+        return ans_grad, torch.ones(ctx.y_shape,
+                                    dtype=ans_grad.dtype,
+                                    device=ans_grad.device)
+
+
+def with_loss(x, y):
+    # returns x but adds y.sum() to the loss function.
+    return WithLoss.apply(x, y)
+
 def _no_op(x: Tensor) -> Tensor:
     if (torch.jit.is_scripting()):
         return x
@@ -617,7 +633,6 @@ def _no_op(x: Tensor) -> Tensor:
         return x.chunk(1, dim=-1)[0]
 
-
 class Identity(torch.nn.Module):
     def __init__(self):
         super(Identity, self).__init__()
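
A small usage sketch of the new helper (again not part of the commit; it assumes scaling.py is importable from the working directory, and the tensors are made up): with_loss returns x unchanged in the forward pass, while in the backward pass y receives a gradient of all ones, which is exactly what adding y.sum() to the training loss would produce.

import torch
from scaling import with_loss  # the helper added in this commit

x = torch.randn(2, 3, requires_grad=True)
y = 0.1 * x.abs()        # some auxiliary penalty computed from x

out = with_loss(x, y)    # forward: out is numerically identical to x
out.sum().backward()

# x.grad is the gradient of out.sum() (all ones) plus the gradient that
# y.sum() would contribute: d/dx [x.sum() + 0.1 * |x|.sum()] = 1 + 0.1 * sign(x).
print(torch.allclose(x.grad, 1.0 + 0.1 * torch.sign(x)))  # True

This keeps the attention code simple: the penalty never has to be threaded back to the top-level loss computation, because autograd delivers its gradient through the same tensor that carries the attention scores.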