From fa843181aff5be214ee499a7e1bfefc9519584f5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 8 Jan 2022 06:59:23 +0800 Subject: [PATCH] Generate k2 graph. --- egs/librispeech/ASR/local/compile_lg.py | 47 +++++++++++++++---------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index e95bf77a1..cb268370c 100644 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -61,13 +61,10 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst: """ tokens = kaldifst.SymbolTable.read_text(f"{lang_dir}/tokens.txt") - words = kaldifst.SymbolTable.read_text(f"{lang_dir}/words.txt") assert "#0" in tokens - assert "#0" in words - token_disambig_id = tokens.find("#0") - word_disambig_id = words.find("#0") + first_token_disambig_id = tokens.find("#0") L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) L = k2_to_openfst(L, olabels="aux_labels") @@ -93,22 +90,34 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst: connect=True, ) - logging.info("Determinize star LG") - kaldifst.determinize_star(LG) + if True: + logging.info("Determinize star LG") + kaldifst.determinize_star(LG) - logging.info("minimizeencoded") - kaldifst.minimize_encoded(LG) + logging.info("minimizeencoded") + kaldifst.minimize_encoded(LG) + else: + # You can use this branch to compare the size of + # the resulting graph + logging.info("Determinize LG") + LG = kaldifst.determinize(LG) - # Set all disambig IDs to eps - for state in kaldifst.StateIterator(LG): - for arc in kaldifst.ArcIterator(LG, state): - if arc.ilabel >= token_disambig_id: - arc.ilabel = 0 + LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False) - if arc.olabel >= word_disambig_id: - arc.olabel = 0 - # reset properties as we changed the arc labels above - LG.properties(0xFFFFFFFF, True) + LG.labels[LG.labels >= first_token_disambig_id] = 0 + + # We only need the labels of LG during beam search decoding + del LG.aux_labels + + LG = k2.remove_epsilon(LG) + logging.info( + f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + LG = k2.connect(LG) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) return LG @@ -116,7 +125,7 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst: def main(): args = get_args() lang_dir = Path(args.lang_dir) - out_filename = lang_dir / "LG.fst" + out_filename = lang_dir / "LG.pt" if out_filename.is_file(): logging.info(f"{out_filename} already exists - skipping") @@ -126,7 +135,7 @@ def main(): LG = compile_LG(lang_dir) logging.info(f"Saving LG to {out_filename}") - LG.write(str(out_filename)) + torch.save(LG.as_dict(), out_filename) if __name__ == "__main__":