diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 51a0f99e8..13d2d890e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -252,7 +252,7 @@ class CachingEvalFunction(torch.autograd.Function): @staticmethod @custom_bwd - def backward(ctx, y_grad: Tensor) -> Tuple[Tensor, None]: + def backward(ctx, y_grad: Tensor): x, y = ctx.saved_tensors x.requires_grad = ctx.x_requires_grad m = ctx.m # e.g. a nn.Module