Minor fixes.

This commit is contained in:
Fangjun Kuang 2021-09-18 16:46:29 +08:00
parent 77993b9552
commit 306c9e1398

View File

@ -255,7 +255,7 @@ def decode_one_batch(
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
return {key: hyps}
if params.method in ["1best", "nbest"]:
@ -386,9 +386,6 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
# TODO: remove it
if batch_idx > 100:
break
hyps_dict = decode_one_batch(
params=params,