mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Replace scale with lattice-score-scale.
This commit is contained in:
parent
f841581fff
commit
3dadffd2b6
@ -59,7 +59,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--scale",
|
"--lattice-score-scale",
|
||||||
type=float,
|
type=float,
|
||||||
default=1.0,
|
default=1.0,
|
||||||
help="The scale to be applied to `lattice.scores`."
|
help="The scale to be applied to `lattice.scores`."
|
||||||
@ -206,7 +206,7 @@ def decode_one_batch(
|
|||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
ref_texts=supervisions["text"],
|
ref_texts=supervisions["text"],
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
scale=params.scale,
|
scale=params.lattice_score_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.method in ["1best", "nbest"]:
|
if params.method in ["1best", "nbest"]:
|
||||||
@ -220,9 +220,9 @@ def decode_one_batch(
|
|||||||
lattice=lattice,
|
lattice=lattice,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
use_double_scores=params.use_double_scores,
|
use_double_scores=params.use_double_scores,
|
||||||
scale=params.scale,
|
scale=params.lattice_score_scale,
|
||||||
)
|
)
|
||||||
key = f"no_rescore-scale-{params.scale}-{params.num_paths}"
|
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
hyps = get_texts(best_path)
|
||||||
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
|
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
|
||||||
@ -243,7 +243,7 @@ def decode_one_batch(
|
|||||||
G=G,
|
G=G,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
lm_scale_list=lm_scale_list,
|
lm_scale_list=lm_scale_list,
|
||||||
scale=params.scale,
|
scale=params.lattice_score_scale,
|
||||||
)
|
)
|
||||||
elif params.method == "whole-lattice-rescoring":
|
elif params.method == "whole-lattice-rescoring":
|
||||||
best_path_dict = rescore_with_whole_lattice(
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
@ -263,7 +263,7 @@ def decode_one_batch(
|
|||||||
memory_key_padding_mask=memory_key_padding_mask,
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
sos_id=sos_id,
|
sos_id=sos_id,
|
||||||
eos_id=eos_id,
|
eos_id=eos_id,
|
||||||
scale=params.scale,
|
scale=params.lattice_score_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert False, f"Unsupported decoding method: {params.method}"
|
assert False, f"Unsupported decoding method: {params.method}"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user