mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix MulForDropout3
This commit is contained in:
parent
8bbcd81604
commit
1580c1c1cc
@ -1878,7 +1878,7 @@ class MulForDropout3(torch.autograd.Function):
|
||||
# grad and is zero-or-one.
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(self, ctx, x, y, alpha):
|
||||
def forward(ctx, x, y, alpha):
|
||||
assert not y.requires_grad
|
||||
ans = x * y * alpha
|
||||
ctx.save_for_backward(ans)
|
||||
@ -1887,7 +1887,7 @@ class MulForDropout3(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(self, ctx, ans_grad):
|
||||
def backward(ctx, ans_grad):
|
||||
ans, = ctx.saved_tensors
|
||||
x_grad = ctx.alpha * ans_grad * (ans != 0)
|
||||
return x_grad, None, None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user