diff --git a/icefall/decode.py b/icefall/decode.py index 736e5200b..d8571f8ea 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -993,6 +993,7 @@ def rescore_with_conformer_lm( ) logging.info(f"path per utt: {path_per_utt}") + device = model.device device2 = masked_lm_model.device alignment = compute_alignment(tokens.to(device2), nbest.shape.to(device2)) @@ -1037,6 +1038,7 @@ def rescore_with_conformer_lm( # hyp - ref masked_lm_scores = tgt_ll_list[1] - tgt_ll_list[0] + masked_lm_scores = masked_lm_scores.to(device) # TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly. token_ids = tokens.tolist()