Fix bug in backprop of random_clamp()

This commit is contained in:
Daniel Povey 2022-10-20 12:49:29 +08:00
parent d37c159174
commit f6b8f0f631

View File

@ -175,7 +175,6 @@ class RandomClampFunction(torch.autograd.Function):
ctx.reflect = reflect
if reflect != 0.0:
ans = ans * (1.0 + reflect) - (x * reflect)
return ans
@staticmethod
@ -185,7 +184,7 @@ class RandomClampFunction(torch.autograd.Function):
reflect = ctx.reflect
if reflect != 0.0:
x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
return ans_grad * is_same.to(ans_grad.dtype), None, None, None, None
return x_grad, None, None, None, None
def random_clamp(x: Tensor,
min: Optional[float] = None,