Mirror of https://github.com/k2-fsa/icefall.git
Fix SoftmaxFunction bug

commit 79f1863a1e
parent 137ac513bf
@@ -274,33 +274,43 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-        try:
+
+        def get_x_grad(y, y_grad, dim):
             with torch.cuda.amp.autocast(enabled=False):
-                if ans.dtype != torch.float32:
-                    ans = ans.to(torch.float32)
-                    x_grad = ans.mul_(ans_grad.to(torch.float32))
-                else:
-                    # out-of-place since it's not a copy
-                    x_grad = ans_grad.to(torch.float32) * ans
-                ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-                x_grad -= ans
-                return x_grad, None
+                y_grad = y_grad.to(torch.float32)
+                y = y.to(torch.float32)
+                x_grad = y_grad * y
+                y *= x_grad.sum(dim=dim, keepdim=True)
+                x_grad -= y
+                return x_grad
+
+        try:
+            if __name__ == '__main__':
+                raise RuntimeError("For testing")
+            x_grad = get_x_grad(ans, ans_grad, ctx.dim)
+            return x_grad, None
         except Exception as e:
-            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try in half precision.")
-            x_grad = None
+            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try another method.")
+
+            dim = ctx.dim
+            if dim < 0:
+                dim = dim + ans.ndim
+            split_dim = 0 if dim != 0 else 1
+            # split_dim is the dimension we split up ans on.
+            num_split = min(8, ans.shape[split_dim])
+            x_grad_split = [
+                get_x_grad(ans_part, ans_grad_part, ctx.dim) for ans_part, ans_grad_part in
+                zip(torch.split(ans, num_split, dim=split_dim), torch.split(ans_grad, num_split, dim=split_dim))
+            ]
+            x_grad = torch.cat(x_grad_split, dim=split_dim)
 
-            ans, = ctx.saved_tensors
-            ans_grad.mul_(ans)
-            x_grad = ans_grad
-            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-            x_grad -= ans
             return x_grad, None
 
 
 def softmax(x: Tensor,
             dim: int):
     return SoftmaxFunction.apply(x, dim)
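
For context, the backward above computes the standard softmax Jacobian-vector product, x_grad = y * (y_grad - sum(y * y_grad, dim)), and the fallback path redoes that computation chunk-by-chunk along a non-softmax dimension to reduce peak memory. Below is a minimal standalone sketch of both ideas, checked against PyTorch autograd; the helper name softmax_backward and the tensor shapes are illustrative only and are not part of the icefall code.

import torch


def softmax_backward(y: torch.Tensor, y_grad: torch.Tensor, dim: int) -> torch.Tensor:
    # Softmax JVP: dx_i = y_i * (dy_i - sum_j y_j * dy_j) along `dim`.
    x_grad = y_grad * y
    return x_grad - y * x_grad.sum(dim=dim, keepdim=True)


if __name__ == "__main__":
    torch.manual_seed(0)
    dim = -1
    x = torch.randn(4, 10, requires_grad=True)
    y = x.softmax(dim=dim)
    y_grad = torch.randn_like(y)

    # Reference gradient from autograd.
    (ref,) = torch.autograd.grad(y, x, grad_outputs=y_grad)

    # Direct formula, using the saved softmax output only.
    direct = softmax_backward(y.detach(), y_grad, dim)
    assert torch.allclose(direct, ref, atol=1e-5)

    # Chunked variant: split along a dimension other than `dim`,
    # compute each piece, and concatenate -- same result, lower peak memory.
    split_dim = 0
    chunks = [
        softmax_backward(y_part, g_part, dim)
        for y_part, g_part in zip(
            torch.split(y.detach(), 2, dim=split_dim),
            torch.split(y_grad, 2, dim=split_dim),
        )
    ]
    chunked = torch.cat(chunks, dim=split_dim)
    assert torch.allclose(chunked, ref, atol=1e-5)
    print("softmax backward formula and chunked variant match autograd")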