mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix MulForDropout3
This commit is contained in:
parent
8bbcd81604
commit
1580c1c1cc
@ -1878,7 +1878,7 @@ class MulForDropout3(torch.autograd.Function):
|
|||||||
# grad and is zero-or-one.
|
# grad and is zero-or-one.
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_fwd
|
@custom_fwd
|
||||||
def forward(self, ctx, x, y, alpha):
|
def forward(ctx, x, y, alpha):
|
||||||
assert not y.requires_grad
|
assert not y.requires_grad
|
||||||
ans = x * y * alpha
|
ans = x * y * alpha
|
||||||
ctx.save_for_backward(ans)
|
ctx.save_for_backward(ans)
|
||||||
@ -1887,7 +1887,7 @@ class MulForDropout3(torch.autograd.Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_bwd
|
@custom_bwd
|
||||||
def backward(self, ctx, ans_grad):
|
def backward(ctx, ans_grad):
|
||||||
ans, = ctx.saved_tensors
|
ans, = ctx.saved_tensors
|
||||||
x_grad = ctx.alpha * ans_grad * (ans != 0)
|
x_grad = ctx.alpha * ans_grad * (ans != 0)
|
||||||
return x_grad, None, None
|
return x_grad, None, None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user