diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 44fc34640..df260f9ba 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -111,6 +111,7 @@ from beam_search import (
     fast_beam_search_nbest_LG,
     fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
+    fast_beam_search_with_nbest_rescoring,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -312,6 +313,35 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )
 
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.0,
+        help="""Softmax temperature.
+        The output of the model is (logits / temperature).log_softmax().
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=Path,
+        default=Path("./data/lm"),
+        help="""Used only when --decoding-method is
+        fast_beam_search_with_nbest_rescoring.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    parser.add_argument(
+        "--words-txt",
+        type=Path,
+        default=Path("./data/lang_bpe_500/words.txt"),
+        help="""Used only when --decoding-method is
+        fast_beam_search_with_nbest_rescoring.
+        It is the word table.
+        """,
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -324,6 +354,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -352,6 +383,11 @@ def decode_one_batch(
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+      G:
+        Optional. Used only when decoding method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle,
+        or fast_beam_search_with_nbest_rescoring.
+        It's an FsaVec containing an acceptor.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -397,6 +433,7 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            temperature=params.temperature,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -411,6 +448,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -425,6 +463,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -440,6 +479,7 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -460,9 +500,32 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            temperature=params.temperature,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+        ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
+        ngram_lm_scale_list += [0.01, 0.02, 0.05]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
+        ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
+        hyp_tokens = fast_beam_search_with_nbest_rescoring(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
+            ngram_lm_scale_list=ngram_lm_scale_list,
+            num_paths=params.num_paths,
+            G=G,
+            sp=sp,
+            word_table=word_table,
+            use_double_scores=True,
+            nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
+        )
     else:
         batch_size = encoder_out.size(0)
 
@@ -496,6 +559,7 @@ def decode_one_batch(
                 f"beam_{params.beam}_"
                 f"max_contexts_{params.max_contexts}_"
                 f"max_states_{params.max_states}"
+                f"_temperature_{params.temperature}"
             ): hyps
         }
     elif params.decoding_method == "fast_beam_search":
@@ -504,8 +568,23 @@ def decode_one_batch(
                 f"beam_{params.beam}_"
                 f"max_contexts_{params.max_contexts}_"
                 f"max_states_{params.max_states}"
+                f"_temperature_{params.temperature}"
             ): hyps
         }
+    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+        prefix = (
+            f"beam_{params.beam}_"
+            f"max_contexts_{params.max_contexts}_"
+            f"max_states_{params.max_states}_"
+            f"num_paths_{params.num_paths}_"
+            f"nbest_scale_{params.nbest_scale}_"
+            f"temperature_{params.temperature}_"
+        )
+        ans: Dict[str, List[List[str]]] = {}
+        for key, hyp in hyp_tokens.items():
+            t: List[str] = sp.decode(hyp)
+            ans[prefix + key] = [s.split() for s in t]
+        return ans
     elif "fast_beam_search" in params.decoding_method:
         key = f"beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"
@@ -518,7 +597,12 @@ def decode_one_batch(
 
         return {key: hyps}
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {
+            (
+                f"beam_size_{params.beam_size}_"
+                f"temperature_{params.temperature}"
+            ): hyps
+        }
 
 
 def decode_dataset(
@@ -528,6 +612,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
@@ -546,6 +631,11 @@ def decode_dataset(
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+      G:
+        Optional. Used only when decoding method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle,
+        or fast_beam_search_with_nbest_rescoring.
+        It's an FsaVec containing an acceptor.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -576,6 +666,7 @@ def decode_dataset(
             word_table=word_table,
             decoding_graph=decoding_graph,
             batch=batch,
+            G=G,
         )
 
         for name, hyps in hyps_dict.items():
@@ -642,6 +733,71 @@ def save_results(
     logging.info(s)
 
 
+def load_ngram_LM(
+    lm_dir: Path, word_table: k2.SymbolTable, device: torch.device
+) -> k2.Fsa:
+    """Read a ngram model from the given directory.
+    Args:
+      lm_dir:
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+      word_table:
+        The word table mapping words to IDs and vice versa.
+      device:
+        The resulting FSA will be moved to this device.
+    Returns:
+      Return an FsaVec containing a single acceptor.
+    """
+    lm_dir = Path(lm_dir)
+    assert lm_dir.is_dir(), f"{lm_dir} does not exist"
+
+    pt_file = lm_dir / "G_4_gram.pt"
+
+    if pt_file.is_file():
+        logging.info(f"Loading pre-compiled {pt_file}")
+        d = torch.load(pt_file, map_location=device)
+        G = k2.Fsa.from_dict(d)
+        G = k2.add_epsilon_self_loops(G)
+        G = k2.arc_sort(G)
+        return G
+
+    txt_file = lm_dir / "G_4_gram.fst.txt"
+
+    assert txt_file.is_file(), f"{txt_file} does not exist"
+    logging.info(f"Loading {txt_file}")
+    logging.warning("It may take 8 minutes (Will be cached for later use).")
+    with open(txt_file) as f:
+        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+
+    # G.aux_labels is not needed in later computations, so
+    # remove it here.
+    del G.aux_labels
+    # Now G is an acceptor
+
+    first_word_disambig_id = word_table["#0"]
+    # CAUTION: The following line is crucial.
+    # Arcs entering the back-off state have label equal to #0.
+    # We have to change it to 0 here.
+    G.labels[G.labels >= first_word_disambig_id] = 0
+
+    # See https://github.com/k2-fsa/k2/issues/874
+    # for why we need to set G.properties to None
+    G.__dict__["_properties"] = None
+
+    G = k2.Fsa.from_fsas([G]).to(device)
+
+    # Save a dummy value so that it can be loaded in C++.
+    # See https://github.com/pytorch/pytorch/issues/67902
+    # for why we need to do this.
+    G.dummy = 1
+
+    logging.info(f"Saving to {pt_file} for later use")
+    torch.save(G.as_dict(), pt_file)
+
+    G = k2.add_epsilon_self_loops(G)
+    G = k2.arc_sort(G)
+    return G
+
+
 @torch.no_grad()
 def main():
     parser = get_parser()
@@ -660,6 +816,7 @@ def main():
         "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
+        "fast_beam_search_with_nbest_rescoring",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
@@ -760,6 +917,19 @@ def main():
                 torch.load(lg_filename, map_location=device)
             )
             decoding_graph.scores *= params.ngram_lm_scale
+        elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+            logging.info(f"Loading word symbol table from {params.words_txt}")
+            word_table = k2.SymbolTable.from_file(params.words_txt)
+
+            G = load_ngram_LM(
+                lm_dir=params.lm_dir,
+                word_table=word_table,
+                device=device,
+            )
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
+            logging.info(f"G properties_str: {G.properties_str}")
         else:
             word_table = None
             decoding_graph = k2.trivial_graph(
@@ -792,6 +962,7 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            G=G,
         )
 
         save_results(