Fix MulForDropout3

This commit is contained in:
Daniel Povey 2023-01-11 12:26:41 +08:00
parent 8bbcd81604
commit 1580c1c1cc

View File

@ -1878,7 +1878,7 @@ class MulForDropout3(torch.autograd.Function):
# grad and is zero-or-one.
@staticmethod
@custom_fwd
def forward(self, ctx, x, y, alpha):
def forward(ctx, x, y, alpha):
assert not y.requires_grad
ans = x * y * alpha
ctx.save_for_backward(ans)
@ -1887,7 +1887,7 @@ class MulForDropout3(torch.autograd.Function):
@staticmethod
@custom_bwd
def backward(self, ctx, ans_grad):
def backward(ctx, ans_grad):
ans, = ctx.saved_tensors
x_grad = ctx.alpha * ans_grad * (ans != 0)
return x_grad, None, None