Bug fix in caching_eval

Daniel Povey 2022-12-15 23:24:36 +08:00
parent d26ee2bf81
commit f5d4fb092d


@@ -265,7 +265,7 @@ class CachingEvalFunction(torch.autograd.Function):
 y2 = m(x)
 assert torch.allclose(y, y2, atol=1.0e-02)
 # this call to backward() should create grads in the module's parameters
-y.backward(gradient=y_grad)
+y2.backward(gradient=y_grad)
 # restore the state from before we entered this function
 random.setstate(random_state)
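
The context suggests a checkpoint-style recompute-in-backward design: the cached output y carries no autograd graph, so backward() has to be called on y2, the output recomputed inside backward() with gradients enabled. A minimal sketch of that pattern follows, with hypothetical names (RecomputeFn, m, x); it illustrates the general technique and is not the repository's actual CachingEvalFunction:

    import random
    import torch

    class RecomputeFn(torch.autograd.Function):
        # Hypothetical checkpoint-style Function: cache the output in
        # forward(), recompute it with grad enabled in backward().
        @staticmethod
        def forward(ctx, m, x):
            ctx.m = m
            ctx.random_state = random.getstate()  # replay RNG on recompute
            with torch.no_grad():
                y = m(x)  # cached output: no autograd graph attached
            ctx.save_for_backward(x, y)
            return y

        @staticmethod
        def backward(ctx, y_grad):
            x, y = ctx.saved_tensors
            random_state = random.getstate()
            random.setstate(ctx.random_state)
            x = x.detach().requires_grad_(True)
            with torch.enable_grad():
                y2 = ctx.m(x)  # recomputed output, attached to the graph
            # the recomputed output should match the cached one
            assert torch.allclose(y, y2, atol=1.0e-02)
            # backward() must run on y2, which is attached to the graph;
            # this is the call the commit fixes (it was y.backward(...)).
            y2.backward(gradient=y_grad)
            # restore the state from before we entered this function
            random.setstate(random_state)
            return None, x.grad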