From c1fc004df543949d03969321eeb94c4c57c0ce4a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 25 Jul 2022 21:10:27 +0800 Subject: [PATCH] minor fixes --- .../ASR/pruned_transducer_stateless3/jit_decode.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_decode.py index 7c306d4f1..c976d7992 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_decode.py @@ -90,7 +90,7 @@ def main(): model = torch.jit.load(params.nn_model_filename) device = torch.device("cpu") - if torch.cuda.is_available() and hasattr( + if torch.cuda.is_available() and not hasattr( model.simple_lm_proj, "_packed_params" ): device = torch.device("cuda", 0) @@ -105,6 +105,8 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + params.suffix = "jit" + model.to(device) model.device = device model.unk_id = params.unk_id