Disable decoder layers in pretrained.py if they are not used.

This commit is contained in:
Fangjun Kuang 2021-10-14 21:16:59 +08:00
parent 2de12b195e
commit b975bdef9e

View File

@@ -226,7 +226,13 @@ def main():
args = parser.parse_args()
params = get_params()
if args.method != "attention-decoder":
# to save memory as the attention decoder
# will not be used
params.num_decoder_layers = 0
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
@@ -248,7 +254,7 @@ def main():
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()