mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Random clip attention scores to -5..5.
This commit is contained in:
parent
6b3f9e5036
commit
c3c655d0bd
@ -34,6 +34,7 @@ from scaling import (
|
||||
Whiten,
|
||||
Identity,
|
||||
_diag,
|
||||
random_clamp
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
@ -941,6 +942,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
attn_weights_max=5.0 if self.training else None,
|
||||
)
|
||||
return x, weights
|
||||
|
||||
@ -959,6 +961,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
attn_weights_max: Optional[float] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@ -1108,6 +1111,13 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
|
||||
attn_output_weights = torch.matmul(q, k) + pos_weights
|
||||
|
||||
if attn_weights_max is not None:
|
||||
attn_output_weights = random_clamp(attn_output_weights,
|
||||
min=-attn_weights_max,
|
||||
max=attn_weights_max,
|
||||
prob=0.5)
|
||||
|
||||
# attn_output_weights: (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
|
||||
@ -158,6 +158,31 @@ class ActivationScaleBalancerFunction(torch.autograd.Function):
|
||||
return x_grad - neg_delta_grad, None, None, None,
|
||||
|
||||
|
||||
class RandomClampFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
min: Optional[float],
|
||||
max: Optional[float],
|
||||
prob: float) -> Tensor:
|
||||
x_clamped = torch.clamp(x, min=min, max=max)
|
||||
mask = torch.rand_like(x) < prob
|
||||
ans = torch.where(mask, x_clamped, x)
|
||||
if x.requires_grad:
|
||||
ctx.save_for_backward(ans == x)
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
||||
is_same, = ctx.saved_tensors
|
||||
return ans_grad * is_same.to(ans_grad.dtype), None, None, None
|
||||
|
||||
def random_clamp(x: Tensor,
|
||||
min: Optional[float] = None,
|
||||
max: Optional[float] = None,
|
||||
prob: float = 0.5):
|
||||
return RandomClampFunction.apply(x, min, max, prob)
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user