Merge branch 'scaled_adam_exp151' into scaled_adam_exp150

commit f08a869769
Author: Daniel Povey
Date: 2022-10-19 19:59:07 +08:00

@@ -255,10 +255,13 @@ class SoftmaxFunction(torch.autograd.Function):
     """
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
-        ans = x.softmax(dim=dim)
-        ctx.save_for_backward(ans)
-        ctx.dim = dim
-        return ans
+        with torch.cuda.amp.autocast(enabled=False):
+            if x.dtype == torch.float16:
+                x = x.to(torch.float32)
+            ans = x.softmax(dim=dim)
+            ctx.save_for_backward(ans)
+            ctx.dim = dim
+            return ans

     @staticmethod
     def backward(ctx, ans_grad: Tensor):
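
The change wraps the forward pass in torch.cuda.amp.autocast(enabled=False) and upcasts float16 inputs to float32, so the exponentiation and normalization inside softmax run at full precision even when the surrounding model trains under autocast. The hunk truncates at backward; the sketch below fills that in with the standard softmax Jacobian-vector product, x_grad = ans * (ans_grad - sum(ans_grad * ans, dim)), computed in float32 to match the forward's precision handling. That backward body and the self-check at the bottom are illustrative assumptions, not the commit's exact code.

import torch
from torch import Tensor


class SoftmaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        # Run the softmax in float32 even under autocast: exp/normalize
        # in float16 are prone to overflow and rounding error.
        with torch.cuda.amp.autocast(enabled=False):
            if x.dtype == torch.float16:
                x = x.to(torch.float32)
            ans = x.softmax(dim=dim)
            ctx.save_for_backward(ans)
            ctx.dim = dim
            return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        # Assumed backward: the standard softmax JVP, y * (g - sum(g * y, dim)),
        # kept in float32 to match the forward.
        (ans,) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans = ans.to(torch.float32)
            ans_grad = ans_grad.to(torch.float32)
            x_grad = ans * (
                ans_grad - (ans_grad * ans).sum(dim=ctx.dim, keepdim=True)
            )
        return x_grad, None  # no gradient for the `dim` argument


if __name__ == "__main__":
    # Sanity check against PyTorch's built-in softmax autograd.
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.arange(8.0)
    SoftmaxFunction.apply(x, -1).mul(w).sum().backward()

    x_ref = x.detach().clone().requires_grad_(True)
    x_ref.softmax(dim=-1).mul(w).sum().backward()
    assert torch.allclose(x.grad, x_ref.grad, atol=1e-6)

A side effect worth noting: because autocast is disabled for the whole block, the saved tensor and the returned output stay in float32, so the backward multiplication never re-rounds through float16.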