diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 4febd2034..0b51057cf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -265,7 +265,7 @@ class CachingEvalFunction(torch.autograd.Function): y2 = m(x) assert torch.allclose(y, y2, atol=1.0e-02) # this call to backward() should create grads in the module's parameters - y.backward(gradient=y_grad) + y2.backward(gradient=y_grad) # restore the state from before we entered this function random.setstate(random_state)