diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 143690c3b..f68051938 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -1878,7 +1878,7 @@ class MulForDropout3(torch.autograd.Function): # grad and is zero-or-one. @staticmethod @custom_fwd - def forward(self, ctx, x, y, alpha): + def forward(ctx, x, y, alpha): assert not y.requires_grad ans = x * y * alpha ctx.save_for_backward(ans) @@ -1887,7 +1887,7 @@ class MulForDropout3(torch.autograd.Function): @staticmethod @custom_bwd - def backward(self, ctx, ans_grad): + def backward(ctx, ans_grad): ans, = ctx.saved_tensors x_grad = ctx.alpha * ans_grad * (ans != 0) return x_grad, None, None