From 36cb279318d85780d7e9348c6014198516ba99b4 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 25 Oct 2022 12:21:22 +0800
Subject: [PATCH] More memory efficient backprop for DoubleSwish.

---
 .../pruned_transducer_stateless7/scaling.py | 49 ++++++++++++++-----
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index f741d853c..e93b8dded 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -915,22 +915,40 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        x = x.detach()
+        x_dtype = x.dtype
+        if x.dtype == torch.float16:
+            x = x.to(torch.float32)
+
         s = torch.sigmoid(x - 1.0)
         y = x * s
+
         if requires_grad:
-            # discretize s. This should be expectation-preserving if we just divide the
-            # result by 255.
-            s = ((s * 255) + torch.rand_like(s)).clamp(max=255).to(torch.uint8)
-            ctx.save_for_backward(s, y)
+            deriv = (y * (1 - s) + s)
+            # notes on derivative of x * sigmoid(x - 1):
+            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
+            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound.
+            # max \simeq 1.1990.  Take ceil to be 1.2 so it's an upper bound.
+            # The combination of "+ torch.rand_like(deriv)" and casting to
+            # torch.uint8 (which floors) should be expectation-preserving.
+            floor = -0.043637
+            ceil = 1.2
+            d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv))
+            if __name__ == "__main__":
+                # for self-testing only.
+                assert d_scaled.min() >= 0.0
+                assert d_scaled.max() < 256.0
+            d_int = d_scaled.to(torch.uint8)
+            ctx.save_for_backward(d_int)
         return y
 
     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
-        s, y = ctx.saved_tensors
-        # converts back to float.
-        s = s.to(y_grad.dtype) * (1.0 / 255)
-        return (y * (1 - s) + s) * y_grad
+        d, = ctx.saved_tensors
+        # the same constants as used in forward pass.
+        floor = -0.043637
+        ceil = 1.2
+        d = (d * ((ceil - floor) / 255.0) + floor)
+        return (y_grad * d)
 
 
 class DoubleSwish(torch.nn.Module):
@@ -1073,10 +1091,19 @@ def _test_basic_norm():
 
 
 def _test_double_swish_deriv():
-    x = torch.randn(10, 12, dtype=torch.double) * 0.5
+    x = torch.randn(10, 12, dtype=torch.double) * 3.0
     x.requires_grad = True
     m = DoubleSwish()
-    torch.autograd.gradcheck(m, x, atol=0.02)
+
+    tol = ((1.2 - (-0.043637)) / 255.0)
+    torch.autograd.gradcheck(m, x, atol=tol)
+
+
+    # for self-test.
+    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    y = m(x)
+
 
 
 def _test_softmax():
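
Below the patch, for context: a minimal, self-contained sketch of the
expectation-preserving uint8 quantization that the new forward() uses to
store the derivative. The floor/ceil constants and the arithmetic mirror
the patch; the quantize/dequantize helper names are invented here purely
for illustration and do not exist in scaling.py.

import torch

floor, ceil = -0.043637, 1.2

def quantize(deriv: torch.Tensor) -> torch.Tensor:
    # Affinely map [floor, ceil] into [0, 256), add uniform dither in [0, 1),
    # then floor by casting to uint8 (values are non-negative here).
    d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
    return d_scaled.to(torch.uint8)

def dequantize(d_int: torch.Tensor) -> torch.Tensor:
    # Inverse of the affine map, as in the new backward().
    return d_int.to(torch.float32) * ((ceil - floor) / 255.0) + floor

x = torch.randn(10000) * 3.0
s = torch.sigmoid(x - 1.0)
deriv = x * s * (1 - s) + s  # d/dx (x * sigmoid(x - 1)), i.e. y * (1 - s) + s

# A single dithered quantization is off by at most one step per element:
step = (ceil - floor) / 255.0
err = (dequantize(quantize(deriv)) - deriv).abs().max()
assert err <= step

# Averaging many independent quantizations recovers deriv, because for
# u ~ U[0, 1) we have E[floor(z + u)] = z: the dithered rounding is unbiased.
est = torch.stack([dequantize(quantize(deriv)) for _ in range(500)]).mean(dim=0)
print(err.item(), (est - deriv).abs().max().item())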
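
A few notes on the design. The memory saving comes from save_for_backward():
the old code kept both the discretized sigmoid s (uint8) and the
full-precision activation y, whereas the new code precomputes the gradient
factor y * (1 - s) + s in forward() and stores only its uint8 quantization,
one byte per element. The new gradcheck tolerance follows from the same
scheme: each dequantized derivative can be off by up to one quantization
step, (1.2 - (-0.043637)) / 255 ~= 0.0049, which is exactly the atol passed
to torch.autograd.gradcheck. One caveat: as written, forward() records
x_dtype but never uses it, so float16 inputs produce float32 outputs;
casting y back (e.g. y = y.to(x_dtype)) would be the natural use of the
saved dtype, but that is not part of this patch.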