mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Reduce mem consumption of softmax backward
This commit is contained in:
parent
465d41c429
commit
bfeeddda81
@@ -276,7 +276,8 @@ class SoftmaxFunction(torch.autograd.Function):
|
||||
ans_grad = ans_grad.to(torch.float32)
|
||||
ans = ans.to(torch.float32)
|
||||
x_grad = ans_grad * ans
|
||||
x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
|
||||
ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
|
||||
x_grad -= ans
|
||||
return x_grad, None
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user