Fix SoftmaxFunction bug

This commit is contained in:
Daniel Povey 2023-05-29 10:55:03 +08:00
parent 137ac513bf
commit 79f1863a1e

View File

@ -274,33 +274,43 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, ans_grad: Tensor):
ans, = ctx.saved_tensors
try:
def get_x_grad(y, y_grad, dim):
with torch.cuda.amp.autocast(enabled=False):
if ans.dtype != torch.float32:
ans = ans.to(torch.float32)
x_grad = ans.mul_(ans_grad.to(torch.float32))
else:
# out-of-place since it's not a copy
x_grad = ans_grad.to(torch.float32) * ans
ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
x_grad -= ans
y_grad = y_grad.to(torch.float32)
y = y.to(torch.float32)
x_grad = y_grad * y
y *= x_grad.sum(dim=dim, keepdim=True)
x_grad -= y
return x_grad
try:
if __name__ == '__main__':
raise RuntimeError("For testing")
x_grad = get_x_grad(ans, ans_grad, ctx.dim)
return x_grad, None
except Exception as e:
logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try in half precision.")
x_grad = None
logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try another method.")
dim = ctx.dim
if dim < 0:
dim = dim + ans.dim
split_dim = 0 if dim != 0 else 1
# split_dim is the dimension we split up ans on.
num_split = min(8, ans.shape[split_dim])
x_grad_split = [
get_x_grad(ans_part, ans_grad_part, ctx.dim) for ans_part, ans_grad_part in
zip(torch.split(ans, num_split, dim=split_dim), torch.split(ans_grad, num_split, dim=split_dim))
]
x_grad = torch.cat(x_grad_split, dim=split_dim)
ans, = ctx.saved_tensors
ans_grad.mul_(ans)
x_grad = ans_grad
ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
x_grad -= ans
return x_grad, None
def softmax(x: Tensor,
dim: int):
return SoftmaxFunction.apply(x, dim)