Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Fix SoftmaxFunction bug

parent 137ac513bf
commit 79f1863a1e
@@ -274,33 +274,43 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-        try:
-            with torch.cuda.amp.autocast(enabled=False):
-                if ans.dtype != torch.float32:
-                    ans = ans.to(torch.float32)
-                    x_grad = ans.mul_(ans_grad.to(torch.float32))
-                else:
-                    # out-of-place since it's not a copy
-                    x_grad = ans_grad.to(torch.float32) * ans
-                ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-                x_grad -= ans
-
-                return x_grad, None
+
+        def get_x_grad(y, y_grad, dim):
+            with torch.cuda.amp.autocast(enabled=False):
+                y_grad = y_grad.to(torch.float32)
+                y = y.to(torch.float32)
+                x_grad = y_grad * y
+                y *= x_grad.sum(dim=dim, keepdim=True)
+                x_grad -= y
+                return x_grad
+
+        try:
+            if __name__ == '__main__':
+                raise RuntimeError("For testing")
+            x_grad = get_x_grad(ans, ans_grad, ctx.dim)
+            return x_grad, None
         except Exception as e:
-            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try in half precision.")
-            x_grad = None
-
-            ans, = ctx.saved_tensors
-            ans_grad.mul_(ans)
-            x_grad = ans_grad
-            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-            x_grad -= ans
+            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try another method.")
+            dim = ctx.dim
+            if dim < 0:
+                dim = dim + ans.ndim
+            split_dim = 0 if dim != 0 else 1
+            # split_dim is the dimension we split up ans on.
+            num_split = min(8, ans.shape[split_dim])
+            x_grad_split = [
+                get_x_grad(ans_part, ans_grad_part, ctx.dim) for ans_part, ans_grad_part in
+                zip(torch.split(ans, num_split, dim=split_dim), torch.split(ans_grad, num_split, dim=split_dim))
+            ]
+            x_grad = torch.cat(x_grad_split, dim=split_dim)
             return x_grad, None



 def softmax(x: Tensor,
             dim: int):
     return SoftmaxFunction.apply(x, dim)
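For reference, the gradient identity used by the new get_x_grad helper is the standard softmax backward: if y = softmax(x, dim), then x_grad = y * (y_grad - sum(y * y_grad, dim)). The commit arranges this as multiply, sum, subtract. The following standalone sketch is not part of the commit (the function name is illustrative); it just checks that arrangement of the formula against PyTorch autograd:

import torch

def manual_softmax_grad(y, y_grad, dim):
    # x_grad = y * (y_grad - sum(y * y_grad, dim)), arranged the same way as
    # get_x_grad(): multiply, then subtract y scaled by the weighted sum.
    x_grad = y_grad * y
    y = y * x_grad.sum(dim=dim, keepdim=True)
    return x_grad - y

torch.manual_seed(0)
x = torch.randn(4, 6, dtype=torch.double, requires_grad=True)
y = x.softmax(dim=-1)
y_grad = torch.randn_like(y)

(expected,) = torch.autograd.grad(y, x, grad_outputs=y_grad)
actual = manual_softmax_grad(y.detach(), y_grad, dim=-1)
assert torch.allclose(actual, expected, atol=1e-10)
print("manual softmax backward matches autograd")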
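The except branch recomputes the gradient in chunks along a dimension other than the softmax dimension (split_dim), so each chunk's gradient is independent and concatenating the per-chunk results gives the same answer as the one-shot computation, only with smaller float32 temporaries. A standalone sketch of that idea (illustrative, not from the commit; the helper mirrors get_x_grad but uses out-of-place ops so the inputs are left untouched between the two runs):

import torch

def get_x_grad(y, y_grad, dim):
    # Same formula as the helper in the commit, but out-of-place so the
    # caller's tensors are not modified.
    y_grad = y_grad.to(torch.float32)
    y = y.to(torch.float32)
    x_grad = y_grad * y
    y = y * x_grad.sum(dim=dim, keepdim=True)
    return x_grad - y

torch.manual_seed(0)
dim = -1                       # softmax dimension
ans = torch.randn(16, 32).softmax(dim=dim)
ans_grad = torch.randn_like(ans)

split_dim = 0                  # must differ from the softmax dimension
chunk_size = 4                 # torch.split() takes a chunk size, as num_split does in the commit
full = get_x_grad(ans, ans_grad, dim)
chunked = torch.cat(
    [get_x_grad(a, g, dim)
     for a, g in zip(torch.split(ans, chunk_size, dim=split_dim),
                     torch.split(ans_grad, chunk_size, dim=split_dim))],
    dim=split_dim,
)
assert torch.allclose(full, chunked)
print("chunked backward matches the one-shot computation")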