diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 19e8e6fa8..31c389461 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -175,7 +175,6 @@ class RandomClampFunction(torch.autograd.Function): ctx.reflect = reflect if reflect != 0.0: ans = ans * (1.0 + reflect) - (x * reflect) - return ans @staticmethod @@ -185,7 +184,7 @@ class RandomClampFunction(torch.autograd.Function): reflect = ctx.reflect if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) - return ans_grad * is_same.to(ans_grad.dtype), None, None, None, None + return x_grad, None, None, None, None def random_clamp(x: Tensor, min: Optional[float] = None,