Add epsilon self-loops to G.

Fangjun Kuang 2022-05-16 21:45:20 +08:00
parent 9ffc77a0f2
commit 3d833d9430
2 changed files with 10 additions and 2 deletions
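
For context: k2's device-friendly intersection (k2.intersect_device with sorted_match_a=True, used below) treats epsilon as an ordinary label, so the usual recipe is to add epsilon self-loops to an operand, arc-sort it, intersect, and strip the self-loops from the result. A minimal, self-contained sketch of the G-side preparation this commit adopts; the toy acceptor stands in for the real G_4_gram and is not icefall code:

    import k2

    # Toy word-level acceptor standing in for G_4_gram. Each arc line is
    # "src_state dest_state label score"; arcs entering the final state
    # carry label -1, and the last line names the final state.
    s = """
    0 1 10 0.1
    1 2 20 0.2
    2 3 -1 0.0
    3
    """
    G = k2.Fsa.from_str(s)
    G = k2.add_epsilon_self_loops(G)  # epsilon (label 0) self-loop per state
    G = k2.arc_sort(G)  # sorted arcs are what sorted_match_a=True relies on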


@@ -389,6 +389,7 @@ def fast_beam_search_with_nbest_rescoring(
         ret_arc_maps=False,
     )
 
+    rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
     rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
     ngram_lm_scores = rescored_word_fsas.get_tot_scores(
         use_double_scores=True,
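
The added remove_epsilon_self_loops call undoes, on the intersection output, the self-loops that the operands carry going in. A hedged end-to-end sketch of this hunk's surroundings, with toy FSAs in place of the real G and n-best word FSAs; variable names are illustrative, not copied from icefall:

    import torch
    import k2

    # Toy G (as in the earlier sketch), wrapped into an FsaVec because
    # k2.intersect_device expects FsaVecs on both sides.
    s = """
    0 1 10 0.1
    1 2 20 0.2
    2 3 -1 0.0
    3
    """
    G = k2.Fsa.from_fsas([k2.add_epsilon_self_loops(k2.Fsa.from_str(s))])
    G = k2.arc_sort(G)

    # One toy n-best path over word IDs 10 and 20. This side also gets
    # epsilon self-loops so that epsilon arcs on either operand can match.
    word_fsas = k2.add_epsilon_self_loops(k2.linear_fsa([[10, 20]]))

    b_to_a_map = torch.zeros(1, dtype=torch.int32)  # each path uses G[0]
    rescored = k2.intersect_device(
        a_fsas=G,
        b_fsas=word_fsas,
        b_to_a_map=b_to_a_map,
        sorted_match_a=True,
    )
    # Strip the inherited self-loops (the new line in the hunk above),
    # then connect, top-sort, and read out the total n-gram scores.
    rescored = k2.remove_epsilon_self_loops(rescored)
    rescored = k2.top_sort(k2.connect(rescored))
    print(rescored.get_tot_scores(use_double_scores=True, log_semiring=False))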


@@ -399,8 +399,10 @@ def decode_one_batch(
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
-        ngram_lm_scale_list = [-0.3, -0.2, -0.1, -0.05, -0.02, 0]
+        ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
         ngram_lm_scale_list += [0.01, 0.02, 0.05]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
+        ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
         hyp_tokens = fast_beam_search_with_nbest_rescoring(
             model=model,
             decoding_graph=decoding_graph,
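
The grid now brackets zero, from -0.5 up to 3, so the WER sweep can find the best n-gram weight on either side. A toy, self-contained illustration of such a sweep; the score values are made up, and the real combination happens inside fast_beam_search_with_nbest_rescoring:

    # The full grid from the diff, then a toy sweep: combine fixed per-path
    # acoustic and 4-gram scores under each scale and keep the best path.
    ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
    ngram_lm_scale_list += [0.01, 0.02, 0.05]
    ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
    ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]

    am_scores = [-1.2, -0.9, -1.5]  # toy per-path acoustic scores
    lm_scores = [-3.0, -4.0, -2.0]  # toy per-path 4-gram scores
    for scale in ngram_lm_scale_list:
        tot = [am + scale * lm for am, lm in zip(am_scores, lm_scores)]
        best = max(range(len(tot)), key=tot.__getitem__)
        print(f"ngram_lm_scale_{scale}: best path {best}")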
@@ -676,6 +678,8 @@ def load_ngram_LM(
         logging.info(f"Loading pre-compiled {pt_file}")
         d = torch.load(pt_file, map_location=device)
         G = k2.Fsa.from_dict(d)
+        G = k2.add_epsilon_self_loops(G)
+        G = k2.arc_sort(G)
         return G
 
     txt_file = lm_dir / "G_4_gram.fst.txt"
@@ -702,7 +706,6 @@ def load_ngram_LM(
 
     G.__dict__["_properties"] = None
     G = k2.Fsa.from_fsas([G]).to(device)
-    G = k2.arc_sort(G)
 
     # Save a dummy value so that it can be loaded in C++.
     # See https://github.com/pytorch/pytorch/issues/67902
@@ -711,6 +714,9 @@ def load_ngram_LM(
     logging.info(f"Saving to {pt_file} for later use")
     torch.save(G.as_dict(), pt_file)
+
+    G = k2.add_epsilon_self_loops(G)
+    G = k2.arc_sort(G)
     return G
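
Read together, the three load_ngram_LM hunks move self-loop insertion out of the cached artifact: G_4_gram.pt is saved loop-free, and both the pre-compiled branch and the build-from-text branch add the loops, then re-arc-sort, just before returning. A hedged reconstruction of the resulting flow; anything not visible in the hunks (from_openfst, the symbol-table handling, the exact dummy attribute) is assumed:

    import logging
    from pathlib import Path

    import torch
    import k2

    def load_ngram_LM(lm_dir: Path, device: torch.device) -> k2.Fsa:
        pt_file = lm_dir / "G_4_gram.pt"
        if pt_file.is_file():
            logging.info(f"Loading pre-compiled {pt_file}")
            d = torch.load(pt_file, map_location=device)
            G = k2.Fsa.from_dict(d)
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            return G

        txt_file = lm_dir / "G_4_gram.fst.txt"
        with open(txt_file) as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        # (Symbol-table handling from the real script is elided here.)
        G.__dict__["_properties"] = None  # force k2 to recompute properties
        G = k2.Fsa.from_fsas([G]).to(device)

        # Save a dummy value so that it can be loaded in C++.
        # See https://github.com/pytorch/pytorch/issues/67902
        G.dummy = 1  # assumed form of the dummy attribute
        logging.info(f"Saving to {pt_file} for later use")
        torch.save(G.as_dict(), pt_file)

        # The cached copy stays loop-free; loops and sorting happen after
        # saving, so both branches return the same shape of G.
        G = k2.add_epsilon_self_loops(G)
        G = k2.arc_sort(G)
        return G

With this ordering, the arc_sort dropped from the build branch is not lost; it reappears after the loops are added, where it is actually needed.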
@@ -840,6 +846,7 @@ def main():
             word_table=word_table,
             device=device,
         )
+        logging.info(f"G properties_str: {G.properties_str}")
     else:
         word_table = None
         G = None
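
The new log line gives a cheap startup check that the returned G actually carries the expected invariants. properties_str renders k2's property bitmask as text; a tiny example of what it looks like (exact output varies with the FSA and the k2 version):

    import k2

    s = """
    0 1 1 0.5
    1 2 -1 0.0
    2
    """
    G = k2.arc_sort(k2.add_epsilon_self_loops(k2.Fsa.from_str(s)))
    # Prints a summary such as "Valid|Nonempty|ArcSorted|..." for this FSA.
    print(f"G properties_str: {G.properties_str}")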