Fix potential/theoretical issue in backward of LimitParamValue
This commit is contained in:
parent
d1df919547
commit
6ea1706e11
@@ -728,7 +728,7 @@ class LimitParamValue(torch.autograd.Function):
     x, = ctx.saved_tensors
     # where x < ctx.min, ensure all grads are negative (this will tend to make
     # x more positive).
-    x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
+    x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
     # where x > ctx.max, ensure all grads are positive (this will tend to make
     # x more negative).
     x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
|
|||||||
Reference in New Issue
Block a user