mirror of https://github.com/k2-fsa/icefall.git
Add hard limit of attention weights to +- 50
commit 4565d43d5c (parent c3c655d0bd)
@@ -34,7 +34,8 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
-    random_clamp
+    random_clamp,
+    with_loss,
 )
 from torch import Tensor, nn
 
@@ -1110,16 +1111,37 @@ class RelPositionMultiheadAttention(nn.Module):
                                    storage_offset=pos_weights.stride(3) * (seq_len - 1))
 
 
         # caution: they are really scores at this point.
         attn_output_weights = torch.matmul(q, k) + pos_weights
 
         # The following is a soft way of encouraging the attention scores not to be too
         # large; at training time, once they get outside a certain range (currently
         # -5.0..5.0), we randomly either leave them as-is or clamp them to that range.
         if attn_weights_max is not None:
             attn_output_weights = random_clamp(attn_output_weights,
                                                min=-attn_weights_max,
                                                max=attn_weights_max,
                                                prob=0.5)
 
+        # attn_output_weights: (batch, head, time1, time2)
+        if training and random.random() < 0.1:
+            # This is a harder way of keeping the attention scores from becoming too
+            # large: we incur a penalty if any of them has an absolute value greater
+            # than 50.0, which should be outside their normal range.  We use this
+            # mechanism instead of, say, a limit on entropy, because once the entropy
+            # gets very small, gradients through the softmax can become very small,
+            # and mechanisms like that become ineffective.
+            attn_weights_limit = 50.0
+            # caution: this penalty will be affected by grad-scaling in amp.
+            # It's OK; this is just an emergency brake, and under normal
+            # conditions it shouldn't be active.
+            attn_weights_penalty = 1.0e-04
+            aux_loss = attn_weights_penalty * (attn_output_weights.abs() -
+                                               attn_weights_limit).relu()
+            attn_output_weights = with_loss(attn_output_weights,
+                                            aux_loss)
+
+
         # attn_output_weights: (batch, head, time1, time2)
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, seq_len, seq_len
         )
 
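For readers without scaling.py at hand, here is a minimal sketch of the behaviour the "soft way" comment above describes. It is an assumption-laden stand-in (hence the name random_clamp_sketch), not icefall's actual random_clamp, whose implementation is not part of this diff; it only matches the call signature used above.

import torch
from torch import Tensor


def random_clamp_sketch(x: Tensor, min: float, max: float, prob: float) -> Tensor:
    # Hypothetical stand-in for scaling.random_clamp.  With probability `prob`
    # we clamp the values to [min, max] "straight-through", i.e. the forward
    # value is clamped but gradients flow as if no clamp had been applied;
    # otherwise x is returned unchanged.  The real function may randomise per
    # element rather than per call.
    if torch.rand(()).item() >= prob:
        return x
    clamped = x.clamp(min=min, max=max)
    # x + (clamped - x).detach() equals `clamped` in value but keeps x's gradient.
    return x + (clamped - x).detach()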
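A self-contained numerical check of the "emergency brake" penalty added in this hunk: only the amount by which a score's absolute value exceeds the 50.0 limit is penalised, scaled by 1.0e-04, so in-range scores contribute nothing. The example tensor is made up purely for illustration.

import torch

attn_weights_limit = 50.0
attn_weights_penalty = 1.0e-04

scores = torch.tensor([3.0, -12.0, 49.9, 80.0, -130.0])
aux_loss = attn_weights_penalty * (scores.abs() - attn_weights_limit).relu()
print(aux_loss)  # tensor([0.0000, 0.0000, 0.0000, 0.0030, 0.0080])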
@@ -608,6 +608,22 @@ class Whiten(nn.Module):
                                  self.whitening_limit,
                                  self.grad_scale)
 
+
+class WithLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, y: Tensor):
+        ctx.y_shape = y.shape
+        return x
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor):
+        return ans_grad, torch.ones(ctx.y_shape,
+                                    dtype=ans_grad.dtype,
+                                    device=ans_grad.device)
+def with_loss(x, y):
+    # returns x but adds y.sum() to the loss function.
+    return WithLoss.apply(x, y)
+
+
 def _no_op(x: Tensor) -> Tensor:
     if (torch.jit.is_scripting()):
         return x
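A short usage sketch for the WithLoss / with_loss helpers added above, with the definitions repeated so the snippet runs standalone; the tensors x and y below are made-up illustrations. Forward, the wrapped tensor is returned unchanged; backward, a gradient of ones is injected into y, so training behaves as if y.sum() had been added to the loss.

import torch
from torch import Tensor


class WithLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor):
        ctx.y_shape = y.shape
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        return ans_grad, torch.ones(ctx.y_shape,
                                    dtype=ans_grad.dtype,
                                    device=ans_grad.device)


def with_loss(x, y):
    # returns x but adds y.sum() to the loss function.
    return WithLoss.apply(x, y)


x = torch.randn(2, 3, requires_grad=True)
y = (x.abs() - 1.0).relu()     # stand-in auxiliary penalty, like aux_loss above
out = with_loss(x, y)          # forward pass: `out` is numerically identical to x

out.sum().backward()
# x.grad is ones (from out.sum()) plus d(y.sum())/dx, exactly as if we had
# optimised out.sum() + y.sum().
expected = torch.ones_like(x) + (x.abs() > 1.0).float() * x.sign()
print(torch.allclose(x.grad, expected))   # True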
@@ -617,7 +633,6 @@ def _no_op(x: Tensor) -> Tensor:
         return x.chunk(1, dim=-1)[0]
 
 
-
 class Identity(torch.nn.Module):
     def __init__(self):
         super(Identity, self).__init__()