Generate k2 graph.

This commit is contained in:
Fangjun Kuang 2022-01-08 06:59:23 +08:00
parent 5ff660c84d
commit fa843181af

View File

@ -61,13 +61,10 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
"""
tokens = kaldifst.SymbolTable.read_text(f"{lang_dir}/tokens.txt")
words = kaldifst.SymbolTable.read_text(f"{lang_dir}/words.txt")
assert "#0" in tokens
assert "#0" in words
token_disambig_id = tokens.find("#0")
word_disambig_id = words.find("#0")
first_token_disambig_id = tokens.find("#0")
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
L = k2_to_openfst(L, olabels="aux_labels")
@ -93,22 +90,34 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
connect=True,
)
if True:
logging.info("Determinize star LG")
kaldifst.determinize_star(LG)
logging.info("minimizeencoded")
kaldifst.minimize_encoded(LG)
else:
# You can use this branch to compare the size of
# the resulting graph
logging.info("Determinize LG")
LG = kaldifst.determinize(LG)
# Set all disambig IDs to eps
for state in kaldifst.StateIterator(LG):
for arc in kaldifst.ArcIterator(LG, state):
if arc.ilabel >= token_disambig_id:
arc.ilabel = 0
LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False)
if arc.olabel >= word_disambig_id:
arc.olabel = 0
# reset properties as we changed the arc labels above
LG.properties(0xFFFFFFFF, True)
LG.labels[LG.labels >= first_token_disambig_id] = 0
# We only need the labels of LG during beam search decoding
del LG.aux_labels
LG = k2.remove_epsilon(LG)
logging.info(
f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
)
LG = k2.connect(LG)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
return LG
@ -116,7 +125,7 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
out_filename = lang_dir / "LG.fst"
out_filename = lang_dir / "LG.pt"
if out_filename.is_file():
logging.info(f"{out_filename} already exists - skipping")
@ -126,7 +135,7 @@ def main():
LG = compile_LG(lang_dir)
logging.info(f"Saving LG to {out_filename}")
LG.write(str(out_filename))
torch.save(LG.as_dict(), out_filename)
if __name__ == "__main__":