Fix an error.

This commit is contained in:
Fangjun Kuang 2021-11-15 16:03:49 +08:00
parent 878fb40a12
commit 7f5d9a1671

View File

@ -993,6 +993,7 @@ def rescore_with_conformer_lm(
)
logging.info(f"path per utt: {path_per_utt}")
device = model.device
device2 = masked_lm_model.device
alignment = compute_alignment(tokens.to(device2), nbest.shape.to(device2))
@ -1037,6 +1038,7 @@ def rescore_with_conformer_lm(
# hyp - ref
masked_lm_scores = tgt_ll_list[1] - tgt_ll_list[0]
masked_lm_scores = masked_lm_scores.to(device)
# TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
token_ids = tokens.tolist()