From ec9e4cffe273f4cb3e9fe78d6a9ba2d5f61e3b72 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 21 Jun 2022 15:08:07 +0800
Subject: [PATCH] Support using log_add in LG decoding with fast_beam_search.

---
 .../beam_search.py                            | 46 +++++++++++++++++--
 .../ASR/pruned_transducer_stateless/decode.py | 17 +++++--
 2 files changed, 53 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
index 2be509e75..a175a8e03 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
@@ -86,6 +86,7 @@ def fast_beam_search_nbest(
     num_paths: int,
     nbest_scale: float = 0.5,
     use_double_scores: bool = True,
+    use_max: bool = True,
 ) -> List[List[int]]:
     """It limits the maximum number of symbols per frame to 1.
 
@@ -121,6 +122,8 @@ def fast_beam_search_nbest(
       use_double_scores:
         True to use double precision for computation. False to use
         single precision.
+      use_max:
+        False to use log-add to compute total scores. True to use max.
     Returns:
       Return the decoded result.
     """
@@ -141,14 +144,47 @@ def fast_beam_search_nbest(
         nbest_scale=nbest_scale,
     )
 
-    # at this point, nbest.fsa.scores are all zeros.
+    # The following code is modified from nbest.intersect()
+    word_fsa = k2.invert(nbest.fsa)
+    if hasattr(lattice, "aux_labels"):
+        # delete token IDs as it is not needed
+        del word_fsa.aux_labels
+    word_fsa.scores.zero_()
+    word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+    path_to_utt_map = nbest.shape.row_ids(1)
 
-    nbest = nbest.intersect(lattice)
-    # Now nbest.fsa.scores contains acoustic scores
+    if hasattr(lattice, "aux_labels"):
+        # lattice has token IDs as labels and word IDs as aux_labels.
+        # inv_lattice has word IDs as labels and token IDs as aux_labels
+        inv_lattice = k2.invert(lattice)
+        inv_lattice = k2.arc_sort(inv_lattice)
+    else:
+        inv_lattice = k2.arc_sort(lattice)
 
-    max_indexes = nbest.tot_scores().argmax()
+    if inv_lattice.shape[0] == 1:
+        path_lattice = k2.intersect_device(
+            inv_lattice,
+            word_fsa_with_epsilon_loops,
+            b_to_a_map=torch.zeros_like(path_to_utt_map),
+            sorted_match_a=True,
+        )
+    else:
+        path_lattice = k2.intersect_device(
+            inv_lattice,
+            word_fsa_with_epsilon_loops,
+            b_to_a_map=path_to_utt_map,
+            sorted_match_a=True,
+        )
 
-    best_path = k2.index_fsa(nbest.fsa, max_indexes)
+    # path_lattice has word IDs as labels and token IDs as aux_labels
+    path_lattice = k2.top_sort(k2.connect(path_lattice))
+    tot_scores = path_lattice.get_tot_scores(
+        use_double_scores=use_double_scores,
+        log_semiring=(False if use_max else True),
+    )
+    ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+    best_hyp_indexes = ragged_tot_scores.argmax()
+    best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
 
     hyps = get_texts(best_path)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index e8aae7776..45c7d4437 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -154,7 +154,7 @@ def get_parser():
 
     parser.add_argument(
         "--lang-dir",
-        type=str,
+        type=Path,
         default="data/lang_bpe_500",
         help="The lang dir containing word table and LG graph",
     )
@@ -195,8 +195,8 @@ def get_parser():
         type=str2bool,
         default=False,
         help="""Whether to use an LG graph for FSA-based beam search.
-        Used only when --decoding_method is fast_beam_search. If setting true,
-        it assumes there is an LG.pt file in lang_dir.""",
+        Used only when --decoding_method is fast_beam_search. If true,
+        it uses lang_dir/LG.pt during decoding.""",
     )
 
     parser.add_argument(
@@ -320,6 +320,7 @@ def decode_one_batch(
 
     # at entry, feature is (N, T, C)
 
     supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
 
     encoder_out, encoder_out_lens = model.encoder(
@@ -351,6 +352,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            use_max=params.use_max,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -563,7 +565,10 @@ def main():
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
-        params.suffix += f"-use-max-{params.use_max}"
+        if params.use_LG:
+            params.suffix += f"-use-max-{params.use_max}"
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
     elif "beam_search" in params.decoding_method:
         params.suffix += (
             f"-{params.decoding_method}-beam-size-{params.beam_size}"
@@ -632,8 +637,10 @@ def main():
     if params.use_LG:
         lexicon = Lexicon(params.lang_dir)
         word_table = lexicon.word_table
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
         decoding_graph = k2.Fsa.from_dict(
-            torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
+            torch.load(lg_filename, map_location=device)
         )
         decoding_graph.scores *= params.ngram_lm_scale
     else:
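
A note for reviewers on what `use_max` changes (commentary, not part of
the diff): it only selects the semiring k2 uses in `get_tot_scores()`
when ranking the n-best paths. With `log_semiring=False` (the max, i.e.
tropical, semiring) the total score of a lattice is the score of its
single best path; with `log_semiring=True` the scores of alternative
paths are combined with log-add (logsumexp). A minimal standalone sketch
of the difference, on a made-up two-path toy FSA (labels and scores are
purely illustrative):

    import k2

    # A toy acceptor: two parallel arcs from state 0 to state 1 with
    # scores 0.1 and 0.3, then a final arc (label -1) to state 2.
    s = """
    0 1 1 0.1
    0 1 2 0.3
    1 2 -1 0.0
    2
    """
    fsa_vec = k2.create_fsa_vec([k2.Fsa.from_str(s)])

    # Max (tropical) semiring: total score == best single path == 0.3
    print(fsa_vec.get_tot_scores(use_double_scores=True, log_semiring=False))

    # Log semiring: total score == log(exp(0.1) + exp(0.3)) ~= 0.898
    print(fsa_vec.get_tot_scores(use_double_scores=True, log_semiring=True))

This is why WERs can differ between the two settings: with log-add, a
hypothesis supported by many lattice paths can outrank one that relies on
a single high-scoring path.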