mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
minor fixes
This commit is contained in:
parent
90dc5772ec
commit
c1fc004df5
@ -90,7 +90,7 @@ def main():
|
||||
model = torch.jit.load(params.nn_model_filename)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available() and hasattr(
|
||||
if torch.cuda.is_available() and not hasattr(
|
||||
model.simple_lm_proj, "_packed_params"
|
||||
):
|
||||
device = torch.device("cuda", 0)
|
||||
@ -105,6 +105,8 @@ def main():
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
params.suffix = "jit"
|
||||
|
||||
model.to(device)
|
||||
model.device = device
|
||||
model.unk_id = params.unk_id
|
||||
|
Loading…
x
Reference in New Issue
Block a user