From ff6431ed0fb8af160d6ef7048ab5c15643ba3325 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 14 Nov 2022 16:02:38 +0800
Subject: [PATCH] Implement limits on parameter values a different way.

---
 .../pruned_transducer_stateless7/scaling.py   | 73 ++++++++-----------
 .../pruned_transducer_stateless7/zipformer.py | 21 +++---
 2 files changed, 42 insertions(+), 52 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 901e1ec84..cf2f7f3aa 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -158,47 +158,6 @@ class ActivationScaleBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None,
 
 
-class RandomClampFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(
-            ctx,
-            x: Tensor,
-            min: Optional[float],
-            max: Optional[float],
-            prob: float,
-            reflect: float) -> Tensor:
-        kwargs = {}
-        if min is not None:
-            kwargs['min'] = min
-        if max is not None:
-            kwargs['max'] = max
-        x_clamped = torch.clamp(x, **kwargs)
-        mask = torch.rand_like(x) < prob
-        ans = torch.where(mask, x_clamped, x)
-        if x.requires_grad:
-            ctx.save_for_backward(ans == x)
-            ctx.reflect = reflect
-        if reflect != 0.0:
-            ans = ans * (1.0 + reflect) - (x * reflect)
-        return ans
-
-    @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
-        is_same, = ctx.saved_tensors
-        x_grad = ans_grad * is_same.to(ans_grad.dtype)
-        reflect = ctx.reflect
-        if reflect != 0.0:
-            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
-        return x_grad, None, None, None, None
-
-def random_clamp(x: Tensor,
-                 min: Optional[float] = None,
-                 max: Optional[float] = None,
-                 prob: float = 0.5,
-                 reflect: float = 0.0):
-    return RandomClampFunction.apply(x, min, max, prob, reflect)
-
-
 def random_cast_to_half(x: Tensor,
                         min_abs: float = 5.0e-06) -> Tensor:
     """
@@ -757,6 +716,38 @@ def with_loss(x, y):
     return WithLoss.apply(x, y)
 
 
+class LimitParamValue(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, min: float, max: float):
+        ctx.save_for_backward(x)
+        ctx.min = min
+        ctx.max = max
+        return x
+    @staticmethod
+    def backward(ctx, x_grad: Tensor):
+        x, = ctx.saved_tensors
+        # where x < ctx.min, ensure all grads are negative (this will tend to make
+        # x more positive).
+        x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
+        # where x > ctx.max, ensure all grads are positive (this will tend to make
+        # x more negative).
+        x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
+        return x_grad, None, None
+
+def limit_param_value(x: Tensor,
+                      min: float, max: float,
+                      prob: float = 0.2):
+    # You apply this to (typically) an nn.Parameter during training to ensure that
+    # (most of) its elements stay within a supplied range.  This is done by modifying
+    # the gradients in backprop.
+    # It's not necessary to do this on every batch: do it only some of the time,
+    # to save a little time.
+    if random.random() < prob:
+        return LimitParamValue.apply(x, min, max)
+    else:
+        return x
+
+
 def _no_op(x: Tensor) -> Tensor:
     if (torch.jit.is_scripting()):
         return x
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 92c9e2bce..232c0d4c8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -34,11 +34,11 @@ from scaling import (
     Whiten,
     Identity,
     _diag,
-    random_clamp,
     penalize_abs_values_gt,
     softmax,
     ScheduledFloat,
     FloatLike,
+    limit_param_value,
 )
 from torch import Tensor, nn
 
@@ -435,13 +435,12 @@ class ZipformerEncoderLayer(nn.Module):
                              grad_scale=0.01)
 
     def get_bypass_scale(self):
-        if torch.jit.is_scripting() or not self.training or random.random() < 0.5:
-            # the random.random() part is to ensure we get grads if self.bypass_scale becomes out of range
+        if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
-
-        return self.bypass_scale.clamp(min=float(self.bypass_clamp_min),
-                                       max=float(self.bypass_clamp_max))
-
+        else:
+            return limit_param_value(self.bypass_scale,
+                                     min=float(self.bypass_clamp_min),
+                                     max=float(self.bypass_clamp_max))
 
     def forward(
         self,
@@ -860,10 +859,10 @@ class SimpleCombiner(torch.nn.Module):
 
         weight1 = self.weight1
 
-        if self.training and random.random() < 0.5 and self.min_weight != (0., 0.):
-            weight1 = weight1.clamp(min=self.min_weight[0],
-                                    max=1.0-self.min_weight[1])
-
+        if self.training:
+            weight1 = limit_param_value(weight1,
+                                        min=self.min_weight[0],
+                                        max=1.0-self.min_weight[1])
 
         src1_dim = src1.shape[-1]
         src2_dim = src2.shape[-1]
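A quick illustration (not part of the patch itself): the snippet below is a minimal sketch of how the new limit_param_value() behaves, assuming the modified scaling.py above is importable; the toy parameter values, range and loss are made up for this example. Unlike the old random_clamp(), the forward pass is the identity, so training and inference see exactly the same parameter values; the range is enforced only by flipping the sign of gradient elements that would push an out-of-range element further out of range, i.e. purely through the optimizer's updates.

    import torch
    from scaling import limit_param_value  # the helper added by this patch

    # Toy parameter with one element below the range [0, 1] and one above it.
    p = torch.nn.Parameter(torch.tensor([-2.0, 0.5, 3.0]))

    # prob=1.0 so the limiting is applied on this call rather than at random.
    x = limit_param_value(p, min=0.0, max=1.0, prob=1.0)

    # Toy objective whose gradient w.r.t. every element of x is +1.
    loss = (x * torch.ones(3)).sum()
    loss.backward()

    print(x.detach())  # values unchanged: [-2.0, 0.5, 3.0] (forward is the identity)
    print(p.grad)      # [-1., 1., 1.]: the grad of the element below `min` is flipped
                       # so the optimizer step moves it back up; the element above `max`
                       # already has a grad that moves it down, so it is left alone.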