From 09cbc9fdab51e009e54c3496e7beba8b19ce0f21 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 23 Oct 2022 16:59:43 +0800 Subject: [PATCH] Save some memory in the autograd of DoubleSwish. --- .../ASR/pruned_transducer_stateless7/scaling.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 79ed592da..3d1cd5f56 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -918,12 +918,17 @@ class DoubleSwishFunction(torch.autograd.Function): x = x.detach() s = torch.sigmoid(x - 1.0) y = x * s + # discretize s. Note: .to(torch.uint8) rounds down. We'll correct for this + # in an amortized way by adding 0.5 when we convert back to float. + s = (s * 255.999).to(torch.uint8) ctx.save_for_backward(s, y) return y @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors + # converts back to float. + s = (s.to(y_grad.dtype) + 0.5) * (1.0 / 255.999) return (y * (1 - s) + s) * y_grad @@ -1070,7 +1075,7 @@ def _test_double_swish_deriv(): x = torch.randn(10, 12, dtype=torch.double) * 0.5 x.requires_grad = True m = DoubleSwish() - torch.autograd.gradcheck(m, x) + torch.autograd.gradcheck(m, x, atol=0.01) def _test_softmax():