mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix bug in backprop of random_clamp()
This commit is contained in:
parent
d37c159174
commit
f6b8f0f631
@ -175,7 +175,6 @@ class RandomClampFunction(torch.autograd.Function):
|
||||
ctx.reflect = reflect
|
||||
if reflect != 0.0:
|
||||
ans = ans * (1.0 + reflect) - (x * reflect)
|
||||
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
@ -185,7 +184,7 @@ class RandomClampFunction(torch.autograd.Function):
|
||||
reflect = ctx.reflect
|
||||
if reflect != 0.0:
|
||||
x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
|
||||
return ans_grad * is_same.to(ans_grad.dtype), None, None, None, None
|
||||
return x_grad, None, None, None, None
|
||||
|
||||
def random_clamp(x: Tensor,
|
||||
min: Optional[float] = None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user