diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index e31548737..be6f94412 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -915,13 +915,15 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad x = x.detach() s = torch.sigmoid(x - 1.0) y = x * s - # discretize s. This should be expectation-preserving if we just divide the - # result by 255. - s = ((s * 255) + torch.randn_like(s)).to(torch.uint8) - ctx.save_for_backward(s, y) + if requires_grad: + # discretize s. This should be expectation-preserving if we just divide the + # result by 255. + s = ((s * 255) + torch.rand_like(s)).to(torch.uint8) + ctx.save_for_backward(s, y) return y @staticmethod