minor fixes

This commit is contained in:
Fangjun Kuang 2022-07-25 21:10:27 +08:00
parent 90dc5772ec
commit c1fc004df5

View File

@ -90,7 +90,7 @@ def main():
model = torch.jit.load(params.nn_model_filename)
device = torch.device("cpu")
if torch.cuda.is_available() and hasattr(
if torch.cuda.is_available() and not hasattr(
model.simple_lm_proj, "_packed_params"
):
device = torch.device("cuda", 0)
@ -105,6 +105,8 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.suffix = "jit"
model.to(device)
model.device = device
model.unk_id = params.unk_id