diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 470bdd682..6b0ae0931 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -63,7 +63,10 @@ def get_parser():
         type=float,
         default=1.0,
         help="The scale to be applied to `lattice.scores`."
-        "A smaller value results in more unique paths",
+        " It is needed if you use any kind of n-best based rescoring. "
+        "Currently, it is used when the decoding method is one of: nbest, "
+        "nbest-rescoring, attention-decoder, and nbest-oracle. "
+        "A smaller value results in more unique paths.",
     )
 
     return parser
@@ -96,6 +99,8 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # - nbest-oracle
+            # "method": "nbest",
+            # "method": "nbest-rescoring",
             # "method": "whole-lattice-rescoring",
             "method": "attention-decoder",
             # "method": "nbest-oracle",
@@ -215,8 +220,9 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
+            scale=params.scale,
         )
-        key = f"no_rescore-{params.num_paths}"
+        key = f"no_rescore-scale-{params.scale}-{params.num_paths}"
 
         hyps = get_texts(best_path)
         hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@@ -237,6 +243,7 @@
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
+            scale=params.scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
@@ -256,6 +263,7 @@
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
             eos_id=eos_id,
+            scale=params.scale,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
diff --git a/icefall/decode.py b/icefall/decode.py
index 43524a02a..bdcab23f3 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -9,6 +9,36 @@ import torch.nn as nn
 from icefall.lexicon import Lexicon
 
 
+def _get_random_paths(
+    lattice: k2.Fsa,
+    num_paths: int,
+    use_double_scores: bool = True,
+    scale: float = 1.0,
+):
+    """
+    Args:
+      lattice:
+        The decoding lattice, returned by :func:`get_lattice`.
+      num_paths:
+        It specifies the size `n` in n-best. Note: Paths are selected randomly
+        and those containing identical word sequences are removed; only one
+        of them is kept.
+      use_double_scores:
+        True to use double precision floating point in the computation.
+        False to use single precision.
+      scale:
+        It's the scale applied to `lattice.scores`. A smaller value
+        yields more unique paths.
+    Returns:
+      Return a k2.RaggedInt with 3 axes: [seq][path][arc_pos].
+    """
+    saved_scores = lattice.scores.clone()
+    lattice.scores *= scale
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=use_double_scores)
+    lattice.scores = saved_scores
+    return path
+
+
 def _intersect_device(
     a_fsas: k2.Fsa,
     b_fsas: k2.Fsa,
@@ -132,7 +162,10 @@ def one_best_decoding(
 
 
 def nbest_decoding(
-    lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
+    lattice: k2.Fsa,
+    num_paths: int,
+    use_double_scores: bool = True,
+    scale: float = 1.0,
 ) -> k2.Fsa:
     """It implements something like CTC prefix beam search using n-best lists.
 
@@ -155,12 +188,18 @@
     use_double_scores:
       True to use double precision floating point in the computation.
       False to use single precision.
+    scale:
+      It's the scale applied to `lattice.scores`. A smaller value
+      yields more unique paths.
     Returns:
       An FsaVec containing linear FSAs.
     """
-    # First, extract `num_paths` paths for each sequence.
-    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
-    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    path = _get_random_paths(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        scale=scale,
+    )
 
     # word_seq is a k2.RaggedInt sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
@@ -323,7 +362,11 @@ def compute_am_and_lm_scores(
 
 
 def rescore_with_n_best_list(
-    lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float]
+    lattice: k2.Fsa,
+    G: k2.Fsa,
+    num_paths: int,
+    lm_scale_list: List[float],
+    scale: float = 1.0,
 ) -> Dict[str, k2.Fsa]:
     """Decode using n-best list with LM rescoring.
 
@@ -345,6 +388,9 @@
       It is the size `n` in `n-best` list.
     lm_scale_list:
       A list containing lm_scale values.
+    scale:
+      It's the scale applied to `lattice.scores`. A smaller value
+      yields more unique paths.
     Returns:
       A dict of FsaVec, whose key is an lm_scale and the value is the
       best decoding path for each sequence in the lattice.
@@ -359,9 +405,12 @@
     assert G.device == device
     assert hasattr(G, "aux_labels") is False
 
-    # First, extract `num_paths` paths for each sequence.
-    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
-    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    path = _get_random_paths(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=True,
+        scale=scale,
+    )
 
     # word_seq is a k2.RaggedInt sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
@@ -587,11 +636,12 @@
       when calling this function, while its value contains the decoding output.
      `len(ans_dict) == len(ref_texts)`
     """
-    saved_scores = lattice.scores.clone()
-
-    lattice.scores *= scale
-    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
-    lattice.scores = saved_scores
+    path = _get_random_paths(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=True,
+        scale=scale,
+    )
 
     word_seq = k2.index(lattice.aux_labels, path)
     word_seq = k2.ragged.remove_values_leq(word_seq, 0)
@@ -630,6 +680,7 @@ def rescore_with_attention_decoder(
     memory_key_padding_mask: torch.Tensor,
     sos_id: int,
     eos_id: int,
+    scale: float = 1.0,
 ) -> Dict[str, k2.Fsa]:
     """This function extracts n paths from the given lattice and uses an
     attention decoder to rescore them. The path with the highest
@@ -653,6 +704,9 @@
       The token ID for SOS.
     eos_id:
       The token ID for EOS.
+    scale:
+      It's the scale applied to `lattice.scores`. A smaller value
+      yields more unique paths.
     Returns:
       A dict of FsaVec, whose key contains a string
       ngram_lm_scale_attention_scale and the value is the
@@ -660,7 +714,12 @@
     """
     # First, extract `num_paths` paths for each sequence.
     # path is a k2.RaggedInt with axes [seq][path][arc_pos]
-    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    path = _get_random_paths(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=True,
+        scale=scale,
+    )
 
     # word_seq is a k2.RaggedInt sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
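
Note on the `scale` parameter: k2.random_paths draws paths with probability governed by the arc scores, so multiplying lattice.scores by a value below 1.0 acts like raising a sampling temperature. The distribution over competing arcs becomes flatter, lower-scoring arcs are drawn more often, and fewer of the sampled paths collapse into the same word sequence after deduplication, which is what the n-best based methods want. The snippet below is a minimal plain-PyTorch sketch of that temperature effect; it is an illustration only, not part of the patch, and the score values and draw count are invented for the demo.

    import torch

    # Toy "arc scores" in the log domain, standing in for lattice.scores.
    # The values are invented for this illustration.
    scores = torch.tensor([10.0, 2.0, 1.0, 0.5])

    def num_unique(scale: float, num_draws: int = 10) -> int:
        # Multiplying log-scores by scale < 1 flattens the induced
        # distribution -- the same trick _get_random_paths applies to
        # lattice.scores before calling k2.random_paths.
        probs = torch.softmax(scores * scale, dim=0)
        draws = torch.multinomial(probs, num_draws, replacement=True)
        return len(set(draws.tolist()))  # distinct outcomes among the draws

    torch.manual_seed(0)
    for scale in (1.0, 0.5, 0.1):
        print(f"scale={scale}: {num_unique(scale)} unique outcomes in 10 draws")

With scale=1.0 nearly every draw picks the dominant entry, while smaller scales yield visibly more unique outcomes. This mirrors why nbest, nbest-rescoring, attention-decoder, and nbest-oracle thread `scale` through to `_get_random_paths`, whereas whole-lattice-rescoring, which samples no paths, does not take it.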