diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index dd53a1984..21a2cecd2 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -42,6 +42,7 @@ from icefall.decode2 import (
     nbest_decoding,
     nbest_oracle as nbest_oracle2,
     rescore_with_n_best_list as rescore_with_n_best_list2,
+    rescore_with_whole_lattice as rescore_with_whole_lattice2,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -261,9 +262,7 @@ def decode_one_batch(
             )
             hyps = get_texts(best_path)
             hyps = [[word_table[i] for i in ids] for ids in hyps]
-            key = (
-                f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
-            )
+            key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
             return {key: hyps}
         else:
             return nbest_oracle(
@@ -322,9 +321,19 @@ def decode_one_batch(
             scale=params.lattice_score_scale,
         )
     elif params.method == "whole-lattice-rescoring":
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
-        )
+        if True:
+            # TODO: remove "else" branch
+            best_path_dict = rescore_with_whole_lattice2(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=lm_scale_list,
+            )
+        else:
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=lm_scale_list,
+            )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
@@ -345,10 +354,14 @@ def decode_one_batch(
         assert False, f"Unsupported decoding method: {params.method}"
 
     ans = dict()
-    for lm_scale_str, best_path in best_path_dict.items():
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        ans[lm_scale_str] = hyps
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        for lm_scale in lm_scale_list:
+            ans[f"lm_scale_{lm_scale}"] = [[] for _ in range(lattice.shape[0])]
     return ans
 
 
diff --git a/icefall/decode2.py b/icefall/decode2.py
index 6b68456b9..bb38d6026 100644
--- a/icefall/decode2.py
+++ b/icefall/decode2.py
@@ -17,7 +17,8 @@
 # NOTE: This file is a refactor of decode.py
 # We will delete decode.py and rename this file to decode.py
 
-from typing import Dict, List
+import logging
+from typing import Dict, List, Optional, Union
 
 import k2
 import torch
@@ -505,6 +506,27 @@ def rescore_with_n_best_list(
 ) -> Dict[str, k2.Fsa]:
     """Rescore a nbest list with an n-gram LM.
     The path with a maximum score is used as the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc]. It must have the
+        attributes ``aux_labels`` and ``lm_scores``. Its labels are
+        token IDs and its ``aux_labels`` are word IDs.
+      G:
+        An FsaVec containing only a single FSA. It is an n-gram LM.
+      num_paths:
+        Size of the n-best list.
+      lm_scale_list:
+        A list of floats representing LM score scales.
+      lattice_score_scale:
+        Scale to be applied to ``lattice.scores`` when sampling paths
+        using ``k2.random_paths``.
+      use_double_scores:
+        True to use double precision during computation. False to use
+        single precision.
+    Returns:
+      A dict of FsaVec, whose key is an lm_scale and the value is the
+      best decoding path for each utterance in the lattice.
""" device = lattice.device @@ -543,3 +565,77 @@ def rescore_with_n_best_list( key = f"lm_scale_{lm_scale}" ans[key] = best_path return ans + + +def rescore_with_whole_lattice( + lattice: k2.Fsa, + G_with_epsilon_loops: k2.Fsa, + lm_scale_list: Optional[List[float]] = None, + use_double_scores: bool = True, +) -> Union[k2.Fsa, Dict[str, k2.Fsa]]: + # This is not an Nbest based coding method + assert hasattr(lattice, "lm_scores") + assert G_with_epsilon_loops.shape == (1, None, None) + + device = lattice.device + lattice.scores = lattice.scores - lattice.lm_scores + # We will use lm_scores from G, so remove lats.lm_scores here + del lattice.lm_scores + + assert hasattr(G_with_epsilon_loops, "lm_scores") + + # Now, lattice.scores contains only am_scores + + # inv_lattice has word IDs as labels. + # Its aux_labels are token IDs + inv_lattice = k2.invert(lattice) + num_seqs = lattice.shape[0] + + b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32) + + max_loop_count = 10 + loop_count = 0 + while loop_count <= max_loop_count: + loop_count += 1 + try: + rescoring_lattice = k2.intersect_device( + G_with_epsilon_loops, + inv_lattice, + b_to_a_map, + sorted_match_a=True, + ) + rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice)) + break + except RuntimeError as e: + logging.info(f"Caught exception:\n{e}\n") + logging.info( + f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}" + ) + + # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here + # to avoid OOM. You may need to fine tune it. + inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True) + logging.info( + f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}" + ) + if loop_count > max_loop_count: + logging.info("Return None as the resulting lattice is too large") + return None + + # lat has token IDs as labels + # and word IDs as aux_labels. + lat = k2.invert(rescoring_lattice) + + if lm_scale_list is None: + return lat + + ans = dict() + saved_am_scores = lat.scores - lat.lm_scores + for lm_scale in lm_scale_list: + am_scores = saved_am_scores / lm_scale + lat.scores = am_scores + lat.lm_scores + + best_path = k2.shortest_path(lat, use_double_scores=use_double_scores) + key = f"lm_scale_{lm_scale}_yy" + ans[key] = best_path + return ans