From 3d833d9430640791382d5add3a0599a6f6bf4168 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 16 May 2022 21:45:20 +0800 Subject: [PATCH] Add epsilon self-loops to G. --- .../ASR/pruned_transducer_stateless2/beam_search.py | 1 + .../ASR/pruned_transducer_stateless3/decode.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index e49f20e6e..8b2570ee4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -389,6 +389,7 @@ def fast_beam_search_with_nbest_rescoring( ret_arc_maps=False, ) + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) ngram_lm_scores = rescored_word_fsas.get_tot_scores( use_double_scores=True, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 2a76dd31a..0f7114457 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -399,8 +399,10 @@ def decode_one_batch( for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_with_nbest_rescoring": - ngram_lm_scale_list = [-0.3, -0.2, -0.1, -0.05, -0.02, 0] + ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0] ngram_lm_scale_list += [0.01, 0.02, 0.05] + ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8] + ngram_lm_scale_list += [1.0, 1.5, 2.5, 3] hyp_tokens = fast_beam_search_with_nbest_rescoring( model=model, decoding_graph=decoding_graph, @@ -676,6 +678,8 @@ def load_ngram_LM( logging.info(f"Loading pre-compiled {pt_file}") d = torch.load(pt_file, map_location=device) G = k2.Fsa.from_dict(d) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) return G txt_file = lm_dir / "G_4_gram.fst.txt" @@ -702,7 +706,6 @@ def load_ngram_LM( G.__dict__["_properties"] = None G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) # Save a dummy value so that it can be loaded in C++. # See https://github.com/pytorch/pytorch/issues/67902 @@ -711,6 +714,9 @@ def load_ngram_LM( logging.info(f"Saving to {pt_file} for later use") torch.save(G.as_dict(), pt_file) + + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) return G @@ -840,6 +846,7 @@ def main(): word_table=word_table, device=device, ) + logging.info(f"G properties_str: {G.properties_str}") else: word_table = None G = None