From cd7a36b0a2cc75b05a0a7836ae098155700e68ea Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Sun, 26 Sep 2021 11:43:03 +0800
Subject: [PATCH] Rename lattice_score_scale to nbest_scale.

---
 .../recipes/librispeech/conformer_ctc.rst   |  6 ++--
 egs/librispeech/ASR/RESULTS.md              |  2 +-
 egs/librispeech/ASR/conformer_ctc/decode.py | 14 ++++-----
 .../ASR/conformer_ctc/pretrained.py         |  4 +--
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py |  6 ++--
 icefall/decode.py                           | 30 +++++++++----------
 test/test_decode.py                         |  2 +-
 7 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst
index 40100bc5a..a8b0683f4 100644
--- a/docs/source/recipes/librispeech/conformer_ctc.rst
+++ b/docs/source/recipes/librispeech/conformer_ctc.rst
@@ -299,9 +299,9 @@ The commonly used options are:
   .. code-block::

     $ cd egs/librispeech/ASR
-    $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
+    $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5

-  - ``--lattice-score-scale``
+  - ``--nbest-scale``

     It is used to scale down lattice scores so that there are more unique
     paths for rescoring.
@@ -577,7 +577,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
     --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 1.3 \
     --attention-decoder-scale 1.2 \
-    --lattice-score-scale 0.5 \
+    --nbest-scale 0.5 \
     --num-paths 100 \
     --sos-id 1 \
     --eos-id 1 \
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index d04e912bf..43a46a30f 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -40,7 +40,7 @@ python conformer_ctc/train.py --bucketing-sampler True \
                               --full-libri True \
                               --world-size 4

-python conformer_ctc/decode.py --lattice-score-scale 0.5 \
+python conformer_ctc/decode.py --nbest-scale 0.5 \
                                --epoch 34 \
                                --avg 20 \
                                --method attention-decoder \
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index b5b41c82e..48602f243 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -106,7 +106,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""The scale to be applied to `lattice.scores`.
@@ -250,12 +250,12 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=supervisions["text"],
             word_table=word_table,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
             oov="<UNK>",
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
         return {key: hyps}

     if params.method in ["1best", "nbest"]:
@@ -269,9 +269,9 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
-        key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
+        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa

         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
@@ -293,7 +293,7 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
@@ -319,7 +319,7 @@ def decode_one_batch(
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
             eos_id=eos_id,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index c924b87bb..cdd21f410 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -125,7 +125,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""
@@ -336,7 +336,7 @@ def main():
         memory_key_padding_mask=memory_key_padding_mask,
         sos_id=params.sos_id,
         eos_id=params.eos_id,
-        lattice_score_scale=params.lattice_score_scale,
+        nbest_scale=params.nbest_scale,
         ngram_lm_scale=params.ngram_lm_scale,
         attention_scale=params.attention_decoder_scale,
     )
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 1e91b1008..5f01e60da 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -97,7 +97,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""The scale to be applied to `lattice.scores`.
@@ -229,7 +229,7 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
         key = f"no_rescore-{params.num_paths}"
         hyps = get_texts(best_path)
@@ -248,7 +248,7 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     else:
         best_path_dict = rescore_with_whole_lattice(
diff --git a/icefall/decode.py b/icefall/decode.py
index e678e4622..2103885b4 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -180,7 +180,7 @@ class Nbest(object):
         lattice: k2.Fsa,
         num_paths: int,
         use_double_scores: bool = True,
-        lattice_score_scale: float = 0.5,
+        nbest_scale: float = 0.5,
     ) -> "Nbest":
         """Construct an Nbest object by **sampling** `num_paths` from a lattice.

@@ -206,7 +206,7 @@ class Nbest(object):
           Return an Nbest instance.
         """
         saved_scores = lattice.scores.clone()
-        lattice.scores *= lattice_score_scale
+        lattice.scores *= nbest_scale
         # path is a ragged tensor with dtype torch.int32.
         # It has three axes [utt][path][arc_pos]
         path = k2.random_paths(
@@ -446,7 +446,7 @@ def nbest_decoding(
     lattice: k2.Fsa,
     num_paths: int,
     use_double_scores: bool = True,
-    lattice_score_scale: float = 1.0,
+    nbest_scale: float = 1.0,
 ) -> k2.Fsa:
     """It implements something like CTC prefix beam search using n-best lists.

@@ -474,7 +474,7 @@ def nbest_decoding(
       use_double_scores:
         True to use double precision floating point in the computation.
         False to use single precision.
-      lattice_score_scale:
+      nbest_scale:
         It's the scale applied to the `lattice.scores`. A smaller value
         leads to more unique paths at the risk of missing the correct path.
     Returns:
@@ -484,7 +484,7 @@ def nbest_decoding(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )

     # nbest.fsa.scores contains 0s
@@ -505,7 +505,7 @@ def nbest_oracle(
     ref_texts: List[str],
     word_table: k2.SymbolTable,
     use_double_scores: bool = True,
-    lattice_score_scale: float = 0.5,
+    nbest_scale: float = 0.5,
     oov: str = "<UNK>",
 ) -> Dict[str, List[List[int]]]:
     """Select the best hypothesis given a lattice and a reference transcript.
@@ -517,7 +517,7 @@ def nbest_oracle(
     The decoding result returned from this function is the best result that
     we can obtain using n-best decoding with all kinds of rescoring techniques.

-    This function is useful to tune the value of `lattice_score_scale`.
+    This function is useful to tune the value of `nbest_scale`.

     Args:
       lattice:
@@ -533,7 +533,7 @@ def nbest_oracle(
       use_double_scores:
         True to use double precision for computation. False to use
         single precision.
-      lattice_score_scale:
+      nbest_scale:
         It's the scale applied to the lattice.scores. A smaller value
         yields more unique paths.
       oov:
@@ -549,7 +549,7 @@ def nbest_oracle(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )

     hyps = nbest.build_levenshtein_graphs()
@@ -590,7 +590,7 @@ def rescore_with_n_best_list(
     G: k2.Fsa,
     num_paths: int,
     lm_scale_list: List[float],
-    lattice_score_scale: float = 1.0,
+    nbest_scale: float = 1.0,
     use_double_scores: bool = True,
 ) -> Dict[str, k2.Fsa]:
     """Rescore an n-best list with an n-gram LM.
@@ -607,7 +607,7 @@ def rescore_with_n_best_list(
         Size of nbest list.
       lm_scale_list:
         A list of float representing LM score scales.
-      lattice_score_scale:
+      nbest_scale:
        Scale to be applied to ``lattice.score`` when sampling paths
        using ``k2.random_paths``.
      use_double_scores:
@@ -631,7 +631,7 @@ def rescore_with_n_best_list(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )

     # nbest.fsa.scores are all 0s at this point
@@ -769,7 +769,7 @@ def rescore_with_attention_decoder(
     memory_key_padding_mask: Optional[torch.Tensor],
     sos_id: int,
     eos_id: int,
-    lattice_score_scale: float = 1.0,
+    nbest_scale: float = 1.0,
     ngram_lm_scale: Optional[float] = None,
     attention_scale: Optional[float] = None,
     use_double_scores: bool = True,
@@ -796,7 +796,7 @@ def rescore_with_attention_decoder(
         The token ID for SOS.
       eos_id:
         The token ID for EOS.
-      lattice_score_scale:
+      nbest_scale:
         It's the scale applied to `lattice.scores`. A smaller value
         leads to more unique paths at the risk of missing the correct path.
       ngram_lm_scale:
@@ -812,7 +812,7 @@ def rescore_with_attention_decoder(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )

     # nbest.fsa.scores are all 0s at this point
diff --git a/test/test_decode.py b/test/test_decode.py
index 7ef127781..97964ac67 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -43,7 +43,7 @@ def test_nbest_from_lattice():
         lattice=lattice,
         num_paths=10,
         use_double_scores=True,
-        lattice_score_scale=0.5,
+        nbest_scale=0.5,
     )
     # each lattice has only 4 distinct paths that have different word sequences:
     # 10->30
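
Note: `nbest_scale` (the renamed `lattice_score_scale`) is the factor applied
to `lattice.scores` before `k2.random_paths` samples the n-best list. As a
rough, self-contained illustration of why a smaller scale yields more unique
paths (toy numbers, plain PyTorch; this snippet is not icefall code):

    import torch

    # Unnormalized log-scores of three competing lattice paths.
    log_scores = torch.tensor([5.0, 2.0, 1.0])

    for scale in (1.0, 0.5, 0.1):
        # Scaling down the log-scores flattens the distribution that
        # path sampling effectively draws from.
        probs = torch.softmax(log_scores * scale, dim=0)
        print(f"scale={scale}: {probs.tolist()}")

    # With scale=1.0 nearly all probability mass sits on the best path,
    # so repeated sampling returns mostly duplicates. With scale=0.1 the
    # distribution is close to uniform, so a fixed-size n-best list keeps
    # more unique paths, at the risk of dropping the correct one.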