Use full precision to do softmax and store ans.

Daniel Povey 2022-10-19 19:53:53 +08:00
parent 0ad4462632
commit cc15552510


@@ -219,10 +219,13 @@ class SoftmaxFunction(torch.autograd.Function):
     """
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
-        ans = x.softmax(dim=dim)
-        ctx.save_for_backward(ans)
-        ctx.dim = dim
-        return ans
+        with torch.cuda.amp.autocast(enabled=False):
+            if x.dtype == torch.float16:
+                x = x.to(torch.float32)
+            ans = x.softmax(dim=dim)
+            ctx.save_for_backward(ans)
+            ctx.dim = dim
+            return ans
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
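
For context, a minimal standalone sketch of the same idea the new forward implements: disable autocast locally and upcast a float16 input, so the softmax (and the value saved for backward) is computed and stored in float32. The helper name and the demo below are illustrative only, not part of the commit.

import torch
from torch import Tensor


def full_precision_softmax(x: Tensor, dim: int) -> Tensor:
    # Illustrative helper (not from the commit): run softmax with autocast
    # disabled, upcasting a float16 input so the result stays float32.
    with torch.cuda.amp.autocast(enabled=False):
        if x.dtype == torch.float16:
            x = x.to(torch.float32)
        return x.softmax(dim=dim)


if __name__ == "__main__":
    x = torch.randn(4, 8)
    if torch.cuda.is_available():
        x = x.half().cuda()  # mimic a float16 activation produced under autocast
    print(full_precision_softmax(x, dim=-1).dtype)  # torch.float32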