Get the randomized backprop for softmax in autocast mode working.

Daniel Povey 2022-10-20 13:23:48 +08:00
parent d75d646dc4
commit d137118484
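
For context (not part of this commit): under torch.cuda.amp.autocast, softmax is one of the ops that autocast runs in float32, so the activation saved in forward() below is float32 even when the input x was float16. That is presumably why the backward has to consult the saved input dtype (ctx.x_dtype) rather than the dtype of the saved softmax output. A minimal sketch of the dtype behaviour, assuming a CUDA device is available:

import torch

# Illustrative only: matmul runs in float16 under autocast, while softmax
# is autocast to float32, so its output dtype differs from its input dtype.
if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda")
    w = torch.randn(8, 8, device="cuda")
    with torch.cuda.amp.autocast():
        h = x @ w
        print(h.dtype)                  # torch.float16
        print(h.softmax(dim=-1).dtype)  # torch.float32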


@@ -219,30 +219,32 @@ class SoftmaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ans, = ctx.saved_tensors
        if ans.dtype == torch.float16 or ans_grad.dtype == torch.float16:
            # use a randomized approach to convert to float16
            with torch.cuda.amp.autocast(enabled=False):
                ans_grad = ans_grad.to(torch.float32)
                ans = ans.to(torch.float32)
                x_grad = ans_grad * ans
                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
                return random_cast_to_half(x_grad), None
        else:
            with torch.cuda.amp.autocast(enabled=False):
                ans_grad = ans_grad.to(torch.float32)
                ans = ans.to(torch.float32)
                x_grad = ans_grad * ans
                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
                if ctx.x_dtype == torch.float16:
                    x_grad = random_cast_to_half(x_grad)
                return x_grad, None


def softmax(x: Tensor,
            dim: int):
    logging.info(f"torch.is_autocast_enabled()={torch.is_autocast_enabled()}, x dtype={x.dtype}")
    return SoftmaxFunction.apply(x, dim)
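
The backward above calls random_cast_to_half(), which is defined elsewhere in scaling.py and not shown in this hunk. The point of a randomized cast is to avoid the systematic bias that appears when small gradient values are rounded toward zero by a plain float16 cast. Below is a rough sketch of that kind of function; it illustrates the technique and is not necessarily the exact implementation in the repo (the threshold value is an assumption):

import torch
from torch import Tensor

def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    # Sketch only; the real random_cast_to_half in scaling.py may differ.
    # Elements with |x| < min_abs would round toward zero in a plain float16
    # cast.  Instead, keep each such element with probability |x| / min_abs,
    # scaled up to +-min_abs, so the expected value of the result equals x.
    if x.dtype == torch.float16:
        return x
    x_abs = x.abs()
    is_too_small = x_abs < min_abs
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)

For the small elements this preserves the gradient in expectation: E[result] = (|x| / min_abs) * min_abs * sign(x) = x.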
@@ -867,7 +869,6 @@ class DoubleSwish(torch.nn.Module):
def _test_max_eig():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"proportion = {proportion}")
        x = torch.randn(100, 128)
@@ -891,7 +892,7 @@ def _test_max_eig():
        y.backward(gradient=y_grad)
        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad)
            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)
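
Not part of the commit, but a quick way to sanity-check the float32 path of the custom backward: it computes ans * (ans_grad - (ans_grad * ans).sum(dim, keepdim=True)), the standard softmax Jacobian-vector product, so it should agree closely with autograd through x.softmax(). The test below is an illustration, assuming SoftmaxFunction from the first hunk is in scope:

import torch

def _test_softmax_backward():
    # Illustrative check: compare the custom backward against PyTorch's own
    # autograd for float32 inputs.
    x = torch.randn(100, 128, requires_grad=True)
    g = torch.randn(100, 128)

    y = SoftmaxFunction.apply(x, -1)
    y.backward(gradient=g)

    x2 = x.detach().clone().requires_grad_(True)
    y2 = x2.softmax(dim=-1)
    y2.backward(gradient=g)

    assert torch.allclose(x.grad, x2.grad, atol=1.0e-05)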