Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Reduce mem consumption of softmax backward
This commit is contained in:
commit bfeeddda81 (parent 465d41c429)
@@ -276,7 +276,8 @@ class SoftmaxFunction(torch.autograd.Function):
         ans_grad = ans_grad.to(torch.float32)
         ans = ans.to(torch.float32)
         x_grad = ans_grad * ans
-        x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+        ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+        x_grad -= ans
         return x_grad, None
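The change rewrites the correction term of the softmax backward pass to run in place. For y = softmax(x) and upstream gradient g, the input gradient is x_grad = y * g - y * sum(y * g, dim). The old code materialized `ans * x_grad.sum(...)` as a fresh temporary before subtracting; the new code scales the saved `ans` buffer in place and subtracts it in place, so backward needs one fewer temporary the size of the input. Mutating `ans` here is safe because the saved output is not used again after backward. Below is a minimal sketch (the helper names `backward_old` and `backward_new` are illustrative, not from icefall) checking that the two forms agree numerically:

import torch

def backward_old(ans: torch.Tensor, ans_grad: torch.Tensor, dim: int) -> torch.Tensor:
    # Pre-commit version: `ans * x_grad.sum(...)` and the subtraction each
    # allocate a fresh tensor the size of x.
    x_grad = ans_grad * ans
    x_grad = x_grad - ans * x_grad.sum(dim=dim, keepdim=True)
    return x_grad

def backward_new(ans: torch.Tensor, ans_grad: torch.Tensor, dim: int) -> torch.Tensor:
    # Post-commit version: scale `ans` in place, then subtract in place.
    # The clone is only so this demo does not mutate the caller's tensor;
    # the real backward() consumes the saved `ans` directly.
    ans = ans.clone()
    x_grad = ans_grad * ans
    ans *= x_grad.sum(dim=dim, keepdim=True)  # broadcasts the keepdim sum
    x_grad -= ans
    return x_grad

x = torch.randn(4, 10)
ans = torch.softmax(x, dim=-1)
g = torch.randn_like(ans)
assert torch.allclose(backward_old(ans, g, -1), backward_new(ans, g, -1), atol=1e-6)

Note that in the actual commit the in-place ops reuse the `ans` tensor saved by forward(); the clone in the sketch exists only to keep the side-by-side comparison side-effect-free.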