fix scaling converter test for decoder(predictor). (#553)

This commit is contained in:
kobenaxie 2022-08-27 17:26:21 +08:00 committed by GitHub
parent 2636a3dd58
commit 235eb0746f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -181,7 +181,7 @@ def test_convert_scaled_to_non_scaled():
y = torch.randint(low=1, high=vocab_size - 1, size=(N, U)) y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
d1 = model.decoder(y) d1 = model.decoder(y)
d2 = model.decoder(y) d2 = converted_model.decoder(y)
assert torch.allclose(d1, d2) assert torch.allclose(d1, d2)