Generate k2 graph.

This commit is contained in:
Fangjun Kuang 2022-01-08 06:59:23 +08:00
parent 5ff660c84d
commit fa843181af

View File

@ -61,13 +61,10 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
""" """
tokens = kaldifst.SymbolTable.read_text(f"{lang_dir}/tokens.txt") tokens = kaldifst.SymbolTable.read_text(f"{lang_dir}/tokens.txt")
words = kaldifst.SymbolTable.read_text(f"{lang_dir}/words.txt")
assert "#0" in tokens assert "#0" in tokens
assert "#0" in words
token_disambig_id = tokens.find("#0") first_token_disambig_id = tokens.find("#0")
word_disambig_id = words.find("#0")
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
L = k2_to_openfst(L, olabels="aux_labels") L = k2_to_openfst(L, olabels="aux_labels")
@ -93,22 +90,34 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
connect=True, connect=True,
) )
logging.info("Determinize star LG") if True:
kaldifst.determinize_star(LG) logging.info("Determinize star LG")
kaldifst.determinize_star(LG)
logging.info("minimizeencoded") logging.info("minimizeencoded")
kaldifst.minimize_encoded(LG) kaldifst.minimize_encoded(LG)
else:
# You can use this branch to compare the size of
# the resulting graph
logging.info("Determinize LG")
LG = kaldifst.determinize(LG)
# Set all disambig IDs to eps LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False)
for state in kaldifst.StateIterator(LG):
for arc in kaldifst.ArcIterator(LG, state):
if arc.ilabel >= token_disambig_id:
arc.ilabel = 0
if arc.olabel >= word_disambig_id: LG.labels[LG.labels >= first_token_disambig_id] = 0
arc.olabel = 0
# reset properties as we changed the arc labels above # We only need the labels of LG during beam search decoding
LG.properties(0xFFFFFFFF, True) del LG.aux_labels
LG = k2.remove_epsilon(LG)
logging.info(
f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
)
LG = k2.connect(LG)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
return LG return LG
@ -116,7 +125,7 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
def main(): def main():
args = get_args() args = get_args()
lang_dir = Path(args.lang_dir) lang_dir = Path(args.lang_dir)
out_filename = lang_dir / "LG.fst" out_filename = lang_dir / "LG.pt"
if out_filename.is_file(): if out_filename.is_file():
logging.info(f"{out_filename} already exists - skipping") logging.info(f"{out_filename} already exists - skipping")
@ -126,7 +135,7 @@ def main():
LG = compile_LG(lang_dir) LG = compile_LG(lang_dir)
logging.info(f"Saving LG to {out_filename}") logging.info(f"Saving LG to {out_filename}")
LG.write(str(out_filename)) torch.save(LG.as_dict(), out_filename)
if __name__ == "__main__": if __name__ == "__main__":