Replace scale with lattice-score-scale.

This commit is contained in:
Fangjun Kuang 2021-08-19 18:07:17 +08:00
parent f841581fff
commit 3dadffd2b6

View File

@ -59,7 +59,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--scale", "--lattice-score-scale",
type=float, type=float,
default=1.0, default=1.0,
help="The scale to be applied to `lattice.scores`." help="The scale to be applied to `lattice.scores`."
@ -206,7 +206,7 @@ def decode_one_batch(
num_paths=params.num_paths, num_paths=params.num_paths,
ref_texts=supervisions["text"], ref_texts=supervisions["text"],
lexicon=lexicon, lexicon=lexicon,
scale=params.scale, scale=params.lattice_score_scale,
) )
if params.method in ["1best", "nbest"]: if params.method in ["1best", "nbest"]:
@ -220,9 +220,9 @@ def decode_one_batch(
lattice=lattice, lattice=lattice,
num_paths=params.num_paths, num_paths=params.num_paths,
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
scale=params.scale, scale=params.lattice_score_scale,
) )
key = f"no_rescore-scale-{params.scale}-{params.num_paths}" key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@ -243,7 +243,7 @@ def decode_one_batch(
G=G, G=G,
num_paths=params.num_paths, num_paths=params.num_paths,
lm_scale_list=lm_scale_list, lm_scale_list=lm_scale_list,
scale=params.scale, scale=params.lattice_score_scale,
) )
elif params.method == "whole-lattice-rescoring": elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice( best_path_dict = rescore_with_whole_lattice(
@ -263,7 +263,7 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id, sos_id=sos_id,
eos_id=eos_id, eos_id=eos_id,
scale=params.scale, scale=params.lattice_score_scale,
) )
else: else:
assert False, f"Unsupported decoding method: {params.method}" assert False, f"Unsupported decoding method: {params.method}"