From c3c655d0bdce2f71a2d3168f567cecd161e16370 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 19 Oct 2022 11:59:24 +0800
Subject: [PATCH] Random clip attention scores to -5..5.

---
 .../pruned_transducer_stateless7/conformer.py | 10 ++++++++
 .../pruned_transducer_stateless7/scaling.py   | 25 +++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index f15991b20..9f6180eab 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -34,6 +34,7 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
+    random_clamp
 )
 
 from torch import Tensor, nn
@@ -941,6 +942,7 @@ class RelPositionMultiheadAttention(nn.Module):
             training=self.training,
             key_padding_mask=key_padding_mask,
             attn_mask=attn_mask,
+            attn_weights_max=5.0 if self.training else None,
         )
 
         return x, weights
@@ -959,6 +961,7 @@ class RelPositionMultiheadAttention(nn.Module):
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
+        attn_weights_max: Optional[float] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -1108,6 +1111,13 @@ class RelPositionMultiheadAttention(nn.Module):
 
         attn_output_weights = torch.matmul(q, k) + pos_weights
+
+        if attn_weights_max is not None:
+            attn_output_weights = random_clamp(attn_output_weights,
+                                               min=-attn_weights_max,
+                                               max=attn_weights_max,
+                                               prob=0.5)
+
         # attn_output_weights: (batch, head, time1, time2)
         attn_output_weights = attn_output_weights.view(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index d65b5659a..87626b780 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -158,6 +158,31 @@ class ActivationScaleBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None,
 
 
+class RandomClampFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        min: Optional[float],
+        max: Optional[float],
+        prob: float) -> Tensor:
+        x_clamped = torch.clamp(x, min=min, max=max)
+        mask = torch.rand_like(x) < prob
+        ans = torch.where(mask, x_clamped, x)
+        if x.requires_grad:
+            ctx.save_for_backward(ans == x)
+        return ans
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+        is_same, = ctx.saved_tensors
+        return ans_grad * is_same.to(ans_grad.dtype), None, None, None
+
+def random_clamp(x: Tensor,
+                 min: Optional[float] = None,
+                 max: Optional[float] = None,
+                 prob: float = 0.5):
+    return RandomClampFunction.apply(x, min, max, prob)
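
Note (not part of the patch): the change clamps each attention score to [-attn_weights_max, attn_weights_max] with probability 0.5 per element during training, and the custom autograd function passes gradients through only where the output equals the input, i.e. where no clamping actually changed the value. The sketch below is a minimal, self-contained illustration of that behavior; it duplicates the random_clamp logic from the patch so it can run outside the icefall tree, and the toy tensor shape, seed, and the -5..5 threshold used in the demo are illustrative only.

import torch
from torch import Tensor
from typing import Optional, Tuple


class RandomClampFunction(torch.autograd.Function):
    """Mirror of the patch's autograd function: clamp each element with
    probability `prob`; block gradients only where clamping changed the value."""

    @staticmethod
    def forward(ctx, x: Tensor, min: Optional[float],
                max: Optional[float], prob: float) -> Tensor:
        x_clamped = torch.clamp(x, min=min, max=max)
        mask = torch.rand_like(x) < prob          # which elements get clamped
        ans = torch.where(mask, x_clamped, x)
        if x.requires_grad:
            # True where the output is unchanged, so gradients may pass there.
            ctx.save_for_backward(ans == x)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        (is_same,) = ctx.saved_tensors
        return ans_grad * is_same.to(ans_grad.dtype), None, None, None


def random_clamp(x: Tensor, min: Optional[float] = None,
                 max: Optional[float] = None, prob: float = 0.5) -> Tensor:
    return RandomClampFunction.apply(x, min, max, prob)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Toy stand-in for raw attention scores, scaled so many values fall outside [-5, 5].
    scores = (10.0 * torch.randn(4, 8)).requires_grad_(True)
    out = random_clamp(scores, min=-5.0, max=5.0, prob=0.5)
    out.sum().backward()
    # Elements whose value was actually clamped get zero gradient;
    # untouched elements receive gradient 1 from the sum().
    changed = (out != scores).detach()
    assert torch.all(scores.grad[changed] == 0)
    assert torch.all(scores.grad[~changed] == 1)
    print("clamped fraction:", changed.float().mean().item())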