diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 37fb58bbb..dd53a1984 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -32,13 +32,17 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import get_lattice
 from icefall.decode import (
-    one_best_decoding,
+    one_best_decoding,  # done
     rescore_with_attention_decoder,
-    rescore_with_n_best_list,
+    rescore_with_n_best_list,  # done
     rescore_with_whole_lattice,
-    nbest_oracle,
+    nbest_oracle,  # done
+)
+from icefall.decode2 import (
+    nbest_decoding,
+    nbest_oracle as nbest_oracle2,
+    rescore_with_n_best_list as rescore_with_n_best_list2,
 )
-from icefall.decode2 import nbest_decoding, nbest_oracle as nbest_oracle2
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -257,7 +261,10 @@ def decode_one_batch(
             )
             hyps = get_texts(best_path)
             hyps = [[word_table[i] for i in ids] for ids in hyps]
-            return {"nbest-orcale": hyps}
+            key = (
+                f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
+            )
+            return {key: hyps}
         else:
             return nbest_oracle(
                 lattice=lattice,
@@ -297,13 +304,23 @@ def decode_one_batch(
         lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
 
     if params.method == "nbest-rescoring":
-        best_path_dict = rescore_with_n_best_list(
-            lattice=lattice,
-            G=G,
-            num_paths=params.num_paths,
-            lm_scale_list=lm_scale_list,
-            scale=params.lattice_score_scale,
-        )
+        if True:
+            # TODO: remove the "else" branch
+            best_path_dict = rescore_with_n_best_list2(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=lm_scale_list,
+                lattice_score_scale=params.lattice_score_scale,
+            )
+        else:
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=lm_scale_list,
+                scale=params.lattice_score_scale,
+            )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
             lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
@@ -385,7 +402,8 @@ def decode_dataset(
     results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        if batch_idx > 20:
+        # TODO: remove it
+        if batch_idx > 100:
             break
 
         hyps_dict = decode_one_batch(
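Note on the decode.py change above: the fixed (and misspelled) result key "nbest-orcale" is replaced by a key that records the decoding hyperparameters, so oracle results obtained with different settings can be told apart in the result files. A small illustration with hypothetical values (not part of the patch):

    num_paths = 100
    lattice_score_scale = 0.5
    key = f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
    assert key == "oracle_100_lattice_score_scale_0.5"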
diff --git a/icefall/decode2.py b/icefall/decode2.py
index 1c393a920..6b68456b9 100644
--- a/icefall/decode2.py
+++ b/icefall/decode2.py
@@ -25,6 +25,47 @@ import torch
 from icefall.utils import get_texts
 
 
+def _intersect_device(
+    a_fsas: k2.Fsa,
+    b_fsas: k2.Fsa,
+    b_to_a_map: torch.Tensor,
+    sorted_match_a: bool,
+    batch_size: int = 50,
+) -> k2.Fsa:
+    """This is a wrapper of k2.intersect_device and its purpose is to split
+    b_fsas into several batches and process each batch separately to avoid
+    CUDA OOM error.
+
+    The arguments and return value of this function are the same as
+    k2.intersect_device.
+    """
+    num_fsas = b_fsas.shape[0]
+    if num_fsas <= batch_size:
+        return k2.intersect_device(
+            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
+        )
+
+    num_batches = (num_fsas + batch_size - 1) // batch_size
+    splits = []
+    for i in range(num_batches):
+        start = i * batch_size
+        end = min(start + batch_size, num_fsas)
+        splits.append((start, end))
+
+    ans = []
+    for start, end in splits:
+        indexes = torch.arange(start, end).to(b_to_a_map)
+
+        fsas = k2.index_fsa(b_fsas, indexes)
+        b_to_a = k2.index_select(b_to_a_map, indexes)
+        path_lattice = k2.intersect_device(
+            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
+        )
+        ans.append(path_lattice)
+
+    return k2.cat(ans)
+
+
 # TODO(fangjun): Use Kangwei's C++ implementation that also
 # supports List[List[int]]
 def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
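The helper above keeps GPU memory bounded by intersecting at most `batch_size` FSAs at a time and concatenating the per-batch results with `k2.cat`. The standalone sketch below isolates just the index-batching step with toy sizes (hypothetical numbers, no k2 required); note that `.to(b_to_a_map)` makes the index tensor adopt the dtype and device of the original map.

    import torch

    num_fsas = 120  # e.g. 30 utterances x 4 paths each
    batch_size = 50
    b_to_a_map = torch.zeros(num_fsas, dtype=torch.int32)

    num_batches = (num_fsas + batch_size - 1) // batch_size
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        # indexes has dtype int32, same as b_to_a_map
        indexes = torch.arange(start, end).to(b_to_a_map)
        print(i, indexes.numel())  # prints batch sizes 50, 50, 20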
@@ -181,18 +222,22 @@ class Nbest(object):
         fsa = k2.linear_fsa(labels)
         fsa.aux_labels = aux_labels
+        # Caution: fsa.scores are all 0s.
 
         return Nbest(fsa=fsa, shape=utt_to_path_shape)
 
     def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
         """Intersect this Nbest object with a lattice and get 1-best path
         from the resulting FsaVec.
 
-        Caution:
-          We assume `self.fsa.labels` and `lattice.labels` are token IDs.
+        The purpose of this function is to attach scores to an Nbest.
+
 
         Args:
           lattice:
-            An FsaVec with axes [utt][state][arc]
+            An FsaVec with axes [utt][state][arc]. If it has `aux_labels`, then
+            we assume its `labels` are token IDs and `aux_labels` are word IDs.
+            If it has only `labels`, we assume its `labels` are word IDs.
+
           use_double_scores:
             True to use double precision when computing shortest path.
             False to use single precision.
@@ -201,16 +246,6 @@ class Nbest(object):
           while its `fsa` is the 1-best path from intersecting `self.fsa`
           and `lattice`.
         """
-        assert (
-            self.fsa.device == lattice.device
-        ), f"{self.fsa.device} vs {lattice.device}"
-
-        assert len(lattice.shape) == 3, f"{lattice.shape}"
-
-        assert (
-            lattice.arcs.dim0() == self.shape.dim0
-        ), f"{lattice.arcs.dim0()} vs {self.shape.dim0}"
-
         # Note: We view each linear FSA as a word sequence
         # and we use the passed lattice to give each word sequence a score.
         #
@@ -221,8 +256,10 @@ class Nbest(object):
         # We use a word fsa to intersect with k2.invert(lattice)
         word_fsa = k2.invert(self.fsa)
 
-        # delete token IDs as it is not needed
-        del word_fsa.aux_labels
+        if hasattr(lattice, "aux_labels"):
+            # delete token IDs as they are not needed
+            del word_fsa.aux_labels
+        word_fsa.scores.zero_()
 
         word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
             word_fsa
@@ -230,17 +267,28 @@ class Nbest(object):
 
         path_to_utt_map = self.shape.row_ids(1)
 
-        # lattice has token IDs as labels and word IDs as aux_labels.
-        # inv_lattice has word IDs as labels and token IDs as aux_labels
-        inv_lattice = k2.invert(lattice)
-        inv_lattice = k2.arc_sort(inv_lattice)
+        if hasattr(lattice, "aux_labels"):
+            # lattice has token IDs as labels and word IDs as aux_labels.
+            # inv_lattice has word IDs as labels and token IDs as aux_labels
+            inv_lattice = k2.invert(lattice)
+            inv_lattice = k2.arc_sort(inv_lattice)
+        else:
+            inv_lattice = k2.arc_sort(lattice)
 
-        path_lattice = k2.intersect_device(
-            inv_lattice,
-            word_fsa_with_epsilon_loops,
-            b_to_a_map=path_to_utt_map,
-            sorted_match_a=True,
-        )
+        if inv_lattice.shape[0] == 1:
+            path_lattice = _intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=torch.zeros_like(path_to_utt_map),
+                sorted_match_a=True,
+            )
+        else:
+            path_lattice = _intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=path_to_utt_map,
+                sorted_match_a=True,
+            )
 
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
@@ -254,6 +302,29 @@ class Nbest(object):
 
         return Nbest(fsa=one_best, shape=self.shape)
 
+    def compute_am_scores(self) -> k2.RaggedTensor:
+        """Compute AM scores of each linear FSA (i.e., each path within
+        an utterance).
+
+        Hint:
+          `self.fsa.scores` contains two parts: am scores and lm scores.
+
+        Returns:
+          Return a ragged tensor with 2 axes [utt][path_scores].
+          Its dtype is torch.float64.
+        """
+        saved_scores = self.fsa.scores
+
+        # The `scores` of every arc consists of `am_scores` and `lm_scores`
+        self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
+
+        am_scores = self.fsa.get_tot_scores(
+            use_double_scores=True, log_semiring=False
+        )
+        self.fsa.scores = saved_scores
+
+        return k2.RaggedTensor(self.shape, am_scores)
+
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of the FSAs in this Nbest.
 
@@ -263,10 +334,11 @@ class Nbest(object):
 
         Returns:
           Return a ragged tensor with two axes [utt][path_scores].
+          Its dtype is torch.float64.
         """
-        # Use single precision since there are only additions.
+        # Use double precision to match compute_am_scores().
         scores = self.fsa.get_tot_scores(
-            use_double_scores=False, log_semiring=False
+            use_double_scores=True, log_semiring=False
         )
 
         return k2.RaggedTensor(self.shape, scores)
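`compute_am_scores` above exploits the fact that, after `intersect(lattice)`, every arc score in `nbest.fsa` is the sum of an acoustic part and an LM part (the latter also available separately as `lm_scores`), so subtracting `lm_scores` before calling `get_tot_scores` leaves pure AM scores. A tiny numeric sketch with made-up per-arc values:

    import torch

    scores = torch.tensor([-1.2, -0.7, -3.1], dtype=torch.float64)     # am + lm
    lm_scores = torch.tensor([-0.5, -0.2, -1.0], dtype=torch.float64)  # lm part
    am_scores = scores - lm_scores
    print(am_scores)  # tensor([-0.7000, -0.5000, -2.1000], dtype=torch.float64)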
@@ -317,11 +389,13 @@ def nbest_decoding(
         use_double_scores=use_double_scores,
         lattice_score_scale=lattice_score_scale,
     )
+    # nbest.fsa.scores contains 0s
 
     nbest = nbest.intersect(lattice)
+    # now nbest.fsa.scores has been assigned
 
-    # max_indexes contains the indexes for the max scores
-    # of paths within an utterance.
+    # max_indexes contains the index of the path with the maximum score
+    # within each utterance.
     max_indexes = nbest.tot_scores().argmax()
 
     best_path = k2.index_fsa(nbest.fsa, max_indexes)
@@ -419,3 +493,53 @@ def nbest_oracle(
     best_path = k2.index_fsa(nbest.fsa, max_indexes)
 
     return best_path
+
+
+def rescore_with_n_best_list(
+    lattice: k2.Fsa,
+    G: k2.Fsa,
+    num_paths: int,
+    lm_scale_list: List[float],
+    lattice_score_scale: float = 1.0,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    """Rescore an n-best list with an n-gram LM.
+    The path with the maximum score is used as the decoding output.
+    """
+    device = lattice.device
+
+    assert len(lattice.shape) == 3
+    assert hasattr(lattice, "aux_labels")
+    assert hasattr(lattice, "lm_scores")
+
+    assert G.shape == (1, None, None)
+    assert G.device == device
+    assert hasattr(G, "aux_labels") is False
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        lattice_score_scale=lattice_score_scale,
+    )
+    # nbest.fsa.scores are all 0s at this point
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa has its scores set
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    am_scores = nbest.compute_am_scores()
+
+    nbest = nbest.intersect(G)
+    # Now nbest.fsa.scores contains only lm scores
+    lm_scores = nbest.tot_scores()
+
+    ans = dict()
+    for lm_scale in lm_scale_list:
+        tot_scores = am_scores.values / lm_scale + lm_scores.values
+        tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        max_indexes = tot_scores.argmax()
+        best_path = k2.index_fsa(nbest.fsa, max_indexes)
+        key = f"lm_scale_{lm_scale}"
+        ans[key] = best_path
+    return ans
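In the loop at the end of `rescore_with_n_best_list`, the AM scores are divided by `lm_scale`, added to the LM scores, wrapped with the `[utt][path]` shape, and `argmax()` then picks one path per utterance; each selection is stored under the key `lm_scale_<value>`. Below is a minimal sketch of that selection step with toy scores (hypothetical numbers; it assumes a k2 version providing `k2.RaggedTensor` with the constructors and `argmax()` that the patch itself relies on):

    import k2
    import torch

    # Two utterances: 3 paths for the first, 2 for the second.
    shape = k2.RaggedTensor([[0] * 3, [0] * 2]).shape
    am = torch.tensor([-10.0, -12.0, -11.0, -7.0, -6.5], dtype=torch.float64)
    lm = torch.tensor([-4.0, -1.0, -3.0, -2.0, -4.0], dtype=torch.float64)

    for lm_scale in [0.5, 1.0]:
        tot_scores = k2.RaggedTensor(shape, am / lm_scale + lm)
        # one index per utterance, into the flattened list of paths
        max_indexes = tot_scores.argmax()
        print(f"lm_scale_{lm_scale}", max_indexes)
        # expected: paths 0 and 3 for lm_scale 0.5,
        #           paths 1 and 3 for lm_scale 1.0

A larger `lm_scale` gives the n-gram LM more influence, which is why the selected path for the first utterance changes between the two settings.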