Use eval() for the masked LM model in decoding.

This commit is contained in:
Fangjun Kuang 2021-11-16 10:55:14 +08:00
parent 7f5d9a1671
commit b8037cd529

View File

@ -754,6 +754,7 @@ def main():
)
masked_lm_model.to(device2)
masked_lm_model.device = device2
masked_lm_model.eval()
else:
masked_lm_model = None