From d137118484034887da485fc4abaf5a488345b488 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 20 Oct 2022 13:23:48 +0800
Subject: [PATCH] Get the randomized backprop for softmax in autocast mode
 working.

---
 .../pruned_transducer_stateless7/scaling.py | 25 ++++++++++---------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 280153137..f819bdf7c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -219,30 +219,32 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
         ans = x.softmax(dim=dim)
+        # if x dtype is float16, x.softmax() returns a float32 because
+        # (presumably) that op does not support float16, and autocast
+        # is enabled.
         ctx.save_for_backward(ans)
+        ctx.x_dtype = x.dtype
         ctx.dim = dim
         return ans
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-
-        if ans.dtype == torch.float16 or ans_grad.dtype == torch.float16:
-            # use a randomized approach to convert to float16
-            with torch.cuda.amp.autocast(enabled=False):
-                ans_grad = ans_grad.to(torch.float32)
-                ans = ans.to(torch.float32)
-                x_grad = ans_grad * ans
-                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
-            return random_cast_to_half(x_grad), None
-        else:
+        with torch.cuda.amp.autocast(enabled=False):
+            ans_grad = ans_grad.to(torch.float32)
+            ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
             x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            if ctx.x_dtype == torch.float16:
+                x_grad = random_cast_to_half(x_grad)
             return x_grad, None
+
 
 
 def softmax(x: Tensor, dim: int):
+    logging.info(f"torch.is_autocast_enabled()={torch.is_autocast_enabled()}, x dtype={x.dtype}")
     return SoftmaxFunction.apply(x, dim)
 
 
@@ -867,7 +869,6 @@ class DoubleSwish(torch.nn.Module):
 
 
 def _test_max_eig():
-
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"proportion = {proportion}")
         x = torch.randn(100, 128)
@@ -891,7 +892,7 @@ def _test_max_eig():
         y.backward(gradient=y_grad)
 
         if proportion < 0.2:
-            assert torch.allclose(x.grad, y_grad)
+            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
         elif proportion > 1.0:
             assert not torch.allclose(x.grad, y_grad)
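
Note (not part of the patch): the backward pass above calls random_cast_to_half(), which is defined elsewhere in scaling.py and is not shown in this diff. Assuming it performs a randomized, expectation-preserving cast to float16, a minimal sketch could look like the following; this is an illustration of the idea, not the file's actual implementation.

    # Illustrative sketch only: random_cast_to_half() lives elsewhere in
    # scaling.py and is not part of this diff.  The assumption here is that it
    # does a randomized (expectation-preserving) cast to float16.
    import torch
    from torch import Tensor

    def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
        """Casts x to float16.  Elements whose magnitude is too small to survive
        the cast are randomly set to +-min_abs or 0, with probabilities chosen so
        that the expected value equals x (at the cost of some extra variance)."""
        if x.dtype == torch.float16:
            return x
        x_abs = x.abs()
        is_too_small = x_abs < min_abs
        # With probability (x_abs / min_abs), keep a value of +-min_abs; else 0.
        random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
        return torch.where(is_too_small, random_val, x).to(torch.float16)

The point of the randomization is that a plain .to(torch.float16) would flush tiny gradient values to zero deterministically, biasing the gradient; rounding them up to +-min_abs with the matching probability keeps the gradient unbiased in expectation.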