Remove potentially wrong typing info

This commit is contained in:
Daniel Povey 2022-12-15 23:47:41 +08:00
parent 6caaa4e9c6
commit 8d9301e225

View File

@ -252,7 +252,7 @@ class CachingEvalFunction(torch.autograd.Function):
@staticmethod
@custom_bwd
def backward(ctx, y_grad: Tensor) -> Tuple[Tensor, None]:
def backward(ctx, y_grad: Tensor):
x, y = ctx.saved_tensors
x.requires_grad = ctx.x_requires_grad
m = ctx.m # e.g. a nn.Module