Fix potential/theoretical issue in backward of LimitParamValue

Daniel Povey 2022-11-14 23:31:00 +08:00
parent d1df919547
commit 6ea1706e11

@@ -728,7 +728,7 @@ class LimitParamValue(torch.autograd.Function):
         x, = ctx.saved_tensors
         # where x < ctx.min, ensure all grads are negative (this will tend to make
         # x more positive).
-        x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
+        x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
         # where x > ctx.max, ensure all grads are positive (this will tend to make
         # x more negative).
         x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
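The change replaces the first in-place `x_grad *= ...` with an out-of-place multiply. In a custom autograd Function's backward(), the incoming gradient tensor can be shared with other parts of the graph (for example when it is an accumulated gradient buffer or is also consumed by hooks or other backward branches), so mutating it in place can silently corrupt gradients elsewhere; presumably this is the "potential/theoretical issue" of the commit title. Once x_grad has been rebound to a fresh tensor, the second `*=` can stay in place, because it now mutates a tensor the function owns.

Below is a minimal self-contained sketch of the patched function. Only the backward body comes from the hunk; the forward signature, the identity return, and backward's (x_grad, None, None) return lie outside the hunk and are reconstructed assumptions.

    import torch
    from torch import Tensor

    class LimitParamValue(torch.autograd.Function):
        # Identity in forward; backward flips gradient signs so that values
        # below `min` get pushed up and values above `max` get pushed down.
        # Sketch: everything outside the hunk above is an assumption.
        @staticmethod
        def forward(ctx, x: Tensor, min: float, max: float) -> Tensor:
            ctx.save_for_backward(x)
            ctx.min = min
            ctx.max = max
            return x

        @staticmethod
        def backward(ctx, x_grad: Tensor):
            x, = ctx.saved_tensors
            # where x < ctx.min, ensure all grads are negative (this will tend
            # to make x more positive).  Out-of-place multiply: do not mutate
            # the incoming gradient, which autograd may still use elsewhere.
            x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
            # where x > ctx.max, ensure all grads are positive (this will tend
            # to make x more negative).  In-place is safe here: x_grad was
            # rebound to a fresh tensor on the line above.
            x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
            return x_grad, None, None

A quick check of the sign-flipping, with illustrative limits min=0.0, max=1.0:

    x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
    y = LimitParamValue.apply(x, 0.0, 1.0)
    y.backward(torch.tensor([1.0, 1.0, -1.0]))
    print(x.grad)  # tensor([-1., 1., 1.]): flipped where x lies outside [0, 1]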