Randomly clamp attention scores to -5..5.
This commit is contained in: parent 6b3f9e5036, commit c3c655d0bd
@@ -34,6 +34,7 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
+    random_clamp
 )
 from torch import Tensor, nn
@@ -941,6 +942,7 @@ class RelPositionMultiheadAttention(nn.Module):
             training=self.training,
             key_padding_mask=key_padding_mask,
             attn_mask=attn_mask,
+            attn_weights_max=5.0 if self.training else None,
         )
         return x, weights
@@ -959,6 +961,7 @@ class RelPositionMultiheadAttention(nn.Module):
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
+        attn_weights_max: Optional[float] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -1108,6 +1111,13 @@ class RelPositionMultiheadAttention(nn.Module):
 
 
         attn_output_weights = torch.matmul(q, k) + pos_weights
 
+        if attn_weights_max is not None:
+            attn_output_weights = random_clamp(attn_output_weights,
+                                               min=-attn_weights_max,
+                                               max=attn_weights_max,
+                                               prob=0.5)
+
         # attn_output_weights: (batch, head, time1, time2)
 
         attn_output_weights = attn_output_weights.view(
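
For context when reading the hunk above: the clamp is applied to the raw attention scores (torch.matmul(q, k) + pos_weights) before the later masking and softmax, and only when attn_weights_max is set, which the caller passes as 5.0 in training and None at eval time. Below is a minimal standalone sketch of the intended effect, with invented shapes and values, assuming icefall's scaling.py is importable:

import torch
from scaling import random_clamp  # helper added by this commit

# Toy pre-softmax attention scores, shape (batch * head, time1, time2).
scores = torch.randn(2, 3, 3) * 10.0  # some entries fall well outside [-5, 5]

# With prob=0.5, each element is independently replaced by its value
# clamped to [-5, 5]; the remaining elements pass through unchanged.
clamped = random_clamp(scores, min=-5.0, max=5.0, prob=0.5)
probs = clamped.softmax(dim=-1)  # the softmax happens later in the real module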
@@ -158,6 +158,31 @@ class ActivationScaleBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None,
 
 
+class RandomClampFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+            ctx,
+            x: Tensor,
+            min: Optional[float],
+            max: Optional[float],
+            prob: float) -> Tensor:
+        x_clamped = torch.clamp(x, min=min, max=max)
+        mask = torch.rand_like(x) < prob
+        ans = torch.where(mask, x_clamped, x)
+        if x.requires_grad:
+            ctx.save_for_backward(ans == x)
+        return ans
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+        is_same, = ctx.saved_tensors
+        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+
+
+def random_clamp(x: Tensor,
+                 min: Optional[float] = None,
+                 max: Optional[float] = None,
+                 prob: float = 0.5):
+    return RandomClampFunction.apply(x, min, max, prob)
+
+
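
A quick way to see what RandomClampFunction does to gradients: elements whose value was actually changed by the clamp receive zero gradient, while untouched elements pass the gradient through. A small sketch, using prob=1.0 so every element is clamped and the outcome is deterministic (it assumes the definitions above are importable from scaling):

import torch
from scaling import random_clamp

x = torch.tensor([-10.0, 0.0, 10.0], requires_grad=True)
y = random_clamp(x, min=-5.0, max=5.0, prob=1.0)
print(y)       # tensor([-5., 0., 5.], grad_fn=...)

y.sum().backward()
print(x.grad)  # tensor([0., 1., 0.]): zero where clamping changed the value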