Reduce mem consumption of softmax backward

Daniel Povey 2023-05-16 12:18:09 +08:00
parent 465d41c429
commit bfeeddda81


@@ -276,7 +276,8 @@ class SoftmaxFunction(torch.autograd.Function):
         ans_grad = ans_grad.to(torch.float32)
         ans = ans.to(torch.float32)
         x_grad = ans_grad * ans
-        x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+        ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+        x_grad -= ans
         return x_grad, None
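
The gradient itself is unchanged: x_grad = ans * ans_grad - ans * sum(ans * ans_grad, dim). What changes is how the second term is formed. The old line materialised `ans * x_grad.sum(...)` and then an out-of-place difference, i.e. extra full-size temporaries; the new lines scale the float32 working copy of `ans` in place and subtract it from `x_grad` in place, so no additional tensor of the input's size is allocated in this step. Overwriting `ans` is fine here because the name has already been rebound to its float32 version a few lines above and is not read again before the return.

Below is a minimal, standalone sketch (not the icefall code; the helper names and the check at the bottom are invented for illustration) showing that the two formulations compute the same thing:

import torch


def softmax_backward_old(ans: torch.Tensor, ans_grad: torch.Tensor, dim: int) -> torch.Tensor:
    """Out-of-place version (before the commit)."""
    # Softmax gradient: x_grad = ans * ans_grad - ans * sum(ans * ans_grad, dim).
    x_grad = ans_grad * ans
    # `ans * x_grad.sum(...)` and the out-of-place subtraction each allocate a
    # fresh tensor of the full input size.
    x_grad = x_grad - ans * x_grad.sum(dim=dim, keepdim=True)
    return x_grad


def softmax_backward_new(ans: torch.Tensor, ans_grad: torch.Tensor, dim: int) -> torch.Tensor:
    """In-place version (after the commit); mutates `ans`, so pass a scratch copy."""
    x_grad = ans_grad * ans
    # Reuse `ans` as scratch space and subtract in place: no extra full-size
    # temporaries beyond x_grad itself.
    ans *= x_grad.sum(dim=dim, keepdim=True)
    x_grad -= ans
    return x_grad


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(4, 10)
    ans = x.softmax(dim=-1)
    ans_grad = torch.randn_like(ans)
    old = softmax_backward_old(ans, ans_grad, dim=-1)
    # Clone because the new version deliberately overwrites `ans`.
    new = softmax_backward_new(ans.clone(), ans_grad, dim=-1)
    assert torch.allclose(old, new, atol=1e-6)
    print("max abs difference:", (old - new).abs().max().item())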