mirror of https://github.com/k2-fsa/icefall.git
Add epsilon self-loops to G.
This commit is contained in:
parent 9ffc77a0f2
commit 3d833d9430
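In k2, intersection matches arcs by label, so an n-gram LM G needs epsilon self-loops before it can be intersected with lattices that contain epsilon arcs, and the inputs have to be arc-sorted. A minimal sketch of the two calls this commit adds around G, using a toy acceptor rather than the real 4-gram (everything below is illustrative, not code from the commit):

import k2

# Toy acceptor in k2's text format: "src dst label score";
# the last line gives the final state.
G = k2.Fsa.from_str(
    """0 1 10 0.1
1 2 -1 0.2
2"""
)
G = k2.add_epsilon_self_loops(G)  # epsilon (label 0) self-loops let G stay put
G = k2.arc_sort(G)                # k2 intersection expects an arc-sorted input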
@@ -389,6 +389,7 @@ def fast_beam_search_with_nbest_rescoring(
         ret_arc_maps=False,
     )
 
+    rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
     rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
     ngram_lm_scores = rescored_word_fsas.get_tot_scores(
         use_double_scores=True,
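Because the rescoring intersection is now done against a G that carries epsilon self-loops, the resulting rescored_word_fsas inherit those loops; the added line strips them before connecting, top-sorting, and totalling scores. A hedged round-trip sketch on a toy FSA (variable names are illustrative, not this file's):

import k2

fsa = k2.Fsa.from_str(
    """0 1 5 0.5
1 2 -1 0.5
2"""
)
looped = k2.add_epsilon_self_loops(fsa)
clean = k2.remove_epsilon_self_loops(looped)  # undo the loops
clean = k2.top_sort(k2.connect(clean))        # drop dead states, order the rest
vec = k2.create_fsa_vec([clean])              # get_tot_scores operates on an FsaVec
tot = vec.get_tot_scores(use_double_scores=True, log_semiring=False)
print(tot)  # one score per FSA; the single path here sums to 0.5 + 0.5 = 1.0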
@@ -399,8 +399,10 @@ def decode_one_batch(
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
     elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
-        ngram_lm_scale_list = [-0.3, -0.2, -0.1, -0.05, -0.02, 0]
+        ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
         ngram_lm_scale_list += [0.01, 0.02, 0.05]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
+        ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
         hyp_tokens = fast_beam_search_with_nbest_rescoring(
             model=model,
             decoding_graph=decoding_graph,
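The sweep over n-gram LM scales now starts lower (at -0.5) and continues through 3, filling in the previously missing 0.1 to 3 range. A hedged sketch (toy tensors, not this file's variables) of what such a sweep does in n-best rescoring: each scale reweights the LM score before the best hypothesis is chosen:

import torch

am_scores = torch.tensor([-10.0, -11.0, -9.5])      # per-hypothesis acoustic scores
ngram_lm_scores = torch.tensor([-4.0, -2.0, -6.0])  # per-hypothesis 4-gram scores

for scale in [-0.5, -0.2, -0.1, -0.05, -0.02, 0, 0.01, 0.02, 0.05,
              0.1, 0.3, 0.5, 0.8, 1.0, 1.5, 2.5, 3]:
    tot = am_scores + scale * ngram_lm_scores
    print(f"ngram_lm_scale_{scale}: best index {tot.argmax().item()}")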
@@ -676,6 +678,8 @@ def load_ngram_LM(
         logging.info(f"Loading pre-compiled {pt_file}")
         d = torch.load(pt_file, map_location=device)
         G = k2.Fsa.from_dict(d)
+        G = k2.add_epsilon_self_loops(G)
+        G = k2.arc_sort(G)
         return G
 
     txt_file = lm_dir / "G_4_gram.fst.txt"
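Note the ordering: the self-loops are added after the cached G is loaded here, and (in the hunks below) after it is saved, so the .pt file on disk keeps the plain loop-free G while both return paths of load_ngram_LM hand back the same looped, arc-sorted FSA. A hedged sketch of that caching pattern (the file name and FSA are toy stand-ins):

import k2
import torch

pt_file = "G_toy.pt"  # stand-in for the pre-compiled G_4_gram .pt file

G = k2.Fsa.from_str("0 1 3 0.0\n1 2 -1 0.0\n2")
torch.save(G.as_dict(), pt_file)  # the cache stays loop-free

G = k2.Fsa.from_dict(torch.load(pt_file, map_location="cpu"))
G = k2.add_epsilon_self_loops(G)  # loops are re-added after every load
G = k2.arc_sort(G)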
@@ -702,7 +706,6 @@ def load_ngram_LM(
     G.__dict__["_properties"] = None
 
     G = k2.Fsa.from_fsas([G]).to(device)
-    G = k2.arc_sort(G)
 
     # Save a dummy value so that it can be loaded in C++.
     # See https://github.com/pytorch/pytorch/issues/67902
@@ -711,6 +714,9 @@ def load_ngram_LM(
 
     logging.info(f"Saving to {pt_file} for later use")
     torch.save(G.as_dict(), pt_file)
+
+    G = k2.add_epsilon_self_loops(G)
+    G = k2.arc_sort(G)
     return G
 
 
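Together with the deletion in the previous hunk, this moves k2.arc_sort from right after k2.Fsa.from_fsas to after the new add_epsilon_self_loops call, so the G returned on this path is guaranteed arc-sorted with the extra arcs already in place (k2 intersection needs an arc-sorted input). A quick way to confirm the final properties on a toy FSA (assumed workflow, worth verifying locally):

import k2

G = k2.Fsa.from_str("0 1 3 0.0\n0 2 5 0.0\n1 3 -1 0.0\n2 3 -1 0.0\n3")
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
# properties_str is k2's pipe-separated property summary; after arc_sort
# it should include ArcSorted.
print(G.properties_str)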
@@ -840,6 +846,7 @@ def main():
             word_table=word_table,
             device=device,
         )
+        logging.info(f"G properties_str: {G.properties_str}")
     else:
         word_table = None
         G = None
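The new logging line prints k2's human-readable property summary for the loaded G (properties_str reports flags such as validity and arc-sortedness), which makes it easy to check in the decode logs that G came back arc-sorted with its self-loops applied.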