diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 4f22a6f6c..604ac005e 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -59,7 +59,7 @@ def get_parser(): ) parser.add_argument( - "--scale", + "--lattice-score-scale", type=float, default=1.0, help="The scale to be applied to `lattice.scores`." @@ -206,7 +206,7 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=supervisions["text"], lexicon=lexicon, - scale=params.scale, + scale=params.lattice_score_scale, ) if params.method in ["1best", "nbest"]: @@ -220,9 +220,9 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - scale=params.scale, + scale=params.lattice_score_scale, ) - key = f"no_rescore-scale-{params.scale}-{params.num_paths}" + key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" hyps = get_texts(best_path) hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] @@ -243,7 +243,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - scale=params.scale, + scale=params.lattice_score_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( @@ -263,7 +263,7 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, - scale=params.scale, + scale=params.lattice_score_scale, ) else: assert False, f"Unsupported decoding method: {params.method}"