Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Add hard limit of attention weights to +- 50
parent c3c655d0bd
commit 4565d43d5c
@@ -34,7 +34,8 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
-    random_clamp
+    random_clamp,
+    with_loss,
 )
 from torch import Tensor, nn

@@ -1110,16 +1111,37 @@ class RelPositionMultiheadAttention(nn.Module):
                                                  storage_offset=pos_weights.stride(3) * (seq_len - 1))


-
+        # caution: they are really scores at this point.
         attn_output_weights = torch.matmul(q, k) + pos_weights

+        # The following is a soft way of encouraging the attention scores not to be too large;
+        # at training time, once they get outside a certain range (currently -5.0..5.0), we
+        # randomly either leave them as-is or truncate them to that range.
         if attn_weights_max is not None:
             attn_output_weights = random_clamp(attn_output_weights,
                                                min=-attn_weights_max,
                                                max=attn_weights_max,
                                                prob=0.5)

-        # attn_output_weights: (batch, head, time1, time2)
+        if training and random.random() < 0.1:
+            # This is a harder way of limiting the attention scores to not be too large.
+            # It incurs a penalty if any of them has an absolute value greater than 50.0;
+            # this should be outside the normal range of the attention scores.  We use
+            # this mechanism instead of, say, a limit on entropy, because once the entropy
+            # gets very small, gradients through the softmax can become very small, and
+            # some mechanisms like that become ineffective.
+            attn_weights_limit = 50.0
+            # caution: this penalty will be affected by grad-scaling in amp.
+            # It's OK; this is just an emergency brake, and under normal
+            # conditions it shouldn't be active.
+            attn_weights_penalty = 1.0e-04
+            aux_loss = attn_weights_penalty * (attn_output_weights.abs() -
+                                               attn_weights_limit).relu()
+            attn_output_weights = with_loss(attn_output_weights,
+                                            aux_loss)
+
+
+        # attn_output_weights: (batch, head, time1, time2)
         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, seq_len, seq_len
         )
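Note: the random_clamp helper used above lives in scaling.py and is not part of this diff. Based on the comment (out-of-range values are randomly either left as-is or truncated), a hypothetical sketch of a function with that interface might look roughly like the following; the real implementation may differ, e.g. it may decide per call rather than per element, or behave differently outside training.

import torch
from torch import Tensor


def random_clamp_sketch(x: Tensor, min: float, max: float, prob: float = 0.5) -> Tensor:
    # Hypothetical sketch, not the icefall implementation: with probability
    # `prob`, per element, replace the value with its clamped version
    # (limited to [min, max]); otherwise leave it unchanged.
    mask = torch.rand_like(x) < prob
    return torch.where(mask, x.clamp(min=min, max=max), x)

Called as in the diff, e.g. random_clamp_sketch(attn_output_weights, min=-attn_weights_max, max=attn_weights_max, prob=0.5), this leaves in-range scores untouched and, in expectation, truncates only about half of the out-of-range ones.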
@@ -608,6 +608,22 @@ class Whiten(nn.Module):
                                   self.whitening_limit,
                                   self.grad_scale)

+
+class WithLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, y: Tensor):
+        ctx.y_shape = y.shape
+        return x
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor):
+        return ans_grad, torch.ones(ctx.y_shape,
+                                    dtype=ans_grad.dtype,
+                                    device=ans_grad.device)
+def with_loss(x, y):
+    # returns x but adds y.sum() to the loss function.
+    return WithLoss.apply(x, y)
+
+
 def _no_op(x: Tensor) -> Tensor:
     if (torch.jit.is_scripting()):
         return x
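WithLoss returns x unchanged in the forward pass and, in the backward pass, sends a gradient of ones to y; that is equivalent to adding y.sum() to whatever loss the output eventually feeds, which is how the attention-score penalty above reaches the optimizer. A small sanity check of that behaviour (assuming this recipe's scaling.py is importable as shown):

import torch
from scaling import with_loss  # assumption: scaling.py from this recipe is on the path

x = torch.randn(2, 3, requires_grad=True)
# Hinge-style penalty like the attention code: zero inside the allowed range,
# growing linearly once the absolute value exceeds the limit (1.0 here).
penalty = 1.0e-04 * (x.abs() - 1.0).relu()

out = with_loss(x, penalty)   # numerically identical to x
loss = out.sum()              # stand-in for the real training loss
loss.backward()

# x.grad now equals the gradient of out.sum() + penalty.sum(), i.e. the
# penalty was folded into the loss without changing the forward output.
print(torch.equal(out, x))    # True
print(x.grad)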
@@ -617,7 +633,6 @@ def _no_op(x: Tensor) -> Tensor:
         return x.chunk(1, dim=-1)[0]


-
 class Identity(torch.nn.Module):
     def __init__(self):
         super(Identity, self).__init__()