From 7f5d9a167157e3388b91c3b554ec28278b5426ab Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 15 Nov 2021 16:03:49 +0800 Subject: [PATCH] Fix an error. --- icefall/decode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/icefall/decode.py b/icefall/decode.py index 736e5200b..d8571f8ea 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -993,6 +993,7 @@ def rescore_with_conformer_lm( ) logging.info(f"path per utt: {path_per_utt}") + device = model.device device2 = masked_lm_model.device alignment = compute_alignment(tokens.to(device2), nbest.shape.to(device2)) @@ -1037,6 +1038,7 @@ def rescore_with_conformer_lm( # hyp - ref masked_lm_scores = tgt_ll_list[1] - tgt_ll_list[0] + masked_lm_scores = masked_lm_scores.to(device) # TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly. token_ids = tokens.tolist()