Randomly clip attention scores to -5..5.

Daniel Povey 2022-10-19 11:59:24 +08:00
parent 6b3f9e5036
commit c3c655d0bd
2 changed files with 35 additions and 0 deletions


@@ -34,6 +34,7 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
+    random_clamp
 )
 from torch import Tensor, nn

@@ -941,6 +942,7 @@ class RelPositionMultiheadAttention(nn.Module):
             training=self.training,
             key_padding_mask=key_padding_mask,
             attn_mask=attn_mask,
+            attn_weights_max=5.0 if self.training else None,
         )
         return x, weights

@@ -959,6 +961,7 @@ class RelPositionMultiheadAttention(nn.Module):
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
+        attn_weights_max: Optional[float] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:

@@ -1108,6 +1111,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output_weights = torch.matmul(q, k) + pos_weights
+        if attn_weights_max is not None:
+            attn_output_weights = random_clamp(
+                attn_output_weights,
+                min=-attn_weights_max,
+                max=attn_weights_max,
+                prob=0.5,
+            )
         # attn_output_weights: (batch, head, time1, time2)
         attn_output_weights = attn_output_weights.view(
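
A minimal behavioral sketch of what the hook above does (not code from the commit; the helper name demo_attention_scores and the tensor shapes are assumptions for illustration): during training, each attention logit is clamped to [-attn_weights_max, attn_weights_max] with probability 0.5 before the softmax, while at inference time attn_weights_max is None and the logits pass through unchanged.

import torch

def demo_attention_scores(q, k, pos_weights, attn_weights_max=None, prob=0.5):
    # q: (batch, head, time1, d), k: (batch, head, d, time2) -- illustrative shapes
    scores = torch.matmul(q, k) + pos_weights  # (batch, head, time1, time2)
    if attn_weights_max is not None:
        clamped = scores.clamp(min=-attn_weights_max, max=attn_weights_max)
        mask = torch.rand_like(scores) < prob  # each element clamped with probability prob
        scores = torch.where(mask, clamped, scores)
    return scores.softmax(dim=-1)

The commit itself routes this through random_clamp, a custom autograd function added in the second file, which makes the gradient masking explicit: elements that the clamp actually changed receive zero gradient.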


@@ -158,6 +158,31 @@ class ActivationScaleBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None,
 
+class RandomClampFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        min: Optional[float],
+        max: Optional[float],
+        prob: float) -> Tensor:
+        x_clamped = torch.clamp(x, min=min, max=max)
+        mask = torch.rand_like(x) < prob
+        ans = torch.where(mask, x_clamped, x)
+        if x.requires_grad:
+            ctx.save_for_backward(ans == x)
+        return ans
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+        is_same, = ctx.saved_tensors
+        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+
+def random_clamp(x: Tensor,
+                 min: Optional[float] = None,
+                 max: Optional[float] = None,
+                 prob: float = 0.5):
+    return RandomClampFunction.apply(x, min, max, prob)
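
A quick usage sketch of the new helper, assuming scaling.py from this commit is importable; with prob=1.0 every element is subject to the clamp, and the backward pass zeroes the gradient exactly where the clamp changed a value:

import torch
from scaling import random_clamp  # assumes scaling.py is on the Python path

x = torch.tensor([-8.0, -2.0, 0.0, 3.0, 9.0], requires_grad=True)
y = random_clamp(x, min=-5.0, max=5.0, prob=1.0)  # prob=1.0: consider every element for clamping
y.sum().backward()

print(y)       # tensor([-5., -2.,  0.,  3.,  5.], grad_fn=...)
print(x.grad)  # tensor([0., 1., 1., 1., 0.])  -- zero where the value was clamped

With the default prob=0.5, each element is independently either clamped or left untouched, which is how the attention code in the first file uses it during training.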