mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Fix an error.
This commit is contained in:
parent
878fb40a12
commit
7f5d9a1671
@ -993,6 +993,7 @@ def rescore_with_conformer_lm(
|
|||||||
)
|
)
|
||||||
logging.info(f"path per utt: {path_per_utt}")
|
logging.info(f"path per utt: {path_per_utt}")
|
||||||
|
|
||||||
|
device = model.device
|
||||||
device2 = masked_lm_model.device
|
device2 = masked_lm_model.device
|
||||||
|
|
||||||
alignment = compute_alignment(tokens.to(device2), nbest.shape.to(device2))
|
alignment = compute_alignment(tokens.to(device2), nbest.shape.to(device2))
|
||||||
@ -1037,6 +1038,7 @@ def rescore_with_conformer_lm(
|
|||||||
|
|
||||||
# hyp - ref
|
# hyp - ref
|
||||||
masked_lm_scores = tgt_ll_list[1] - tgt_ll_list[0]
|
masked_lm_scores = tgt_ll_list[1] - tgt_ll_list[0]
|
||||||
|
masked_lm_scores = masked_lm_scores.to(device)
|
||||||
|
|
||||||
# TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
|
# TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
|
||||||
token_ids = tokens.tolist()
|
token_ids = tokens.tolist()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user