Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 18:54:18 +00:00)

commit fa843181af (parent 5ff660c84d)

    Generate k2 graph.
@@ -61,13 +61,10 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
     """

     tokens = kaldifst.SymbolTable.read_text(f"{lang_dir}/tokens.txt")
-    words = kaldifst.SymbolTable.read_text(f"{lang_dir}/words.txt")

     assert "#0" in tokens
-    assert "#0" in words

-    token_disambig_id = tokens.find("#0")
-    word_disambig_id = words.find("#0")
+    first_token_disambig_id = tokens.find("#0")

     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
     L = k2_to_openfst(L, olabels="aux_labels")
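The word-level symbol table drops out here because the disambiguation symbols are now zeroed on the k2 side (next hunk) using a single token-level threshold. A minimal sketch of the assumption behind that threshold, with a hypothetical lang dir path; the calls mirror the ones this hunk keeps:

    import kaldifst

    # Hypothetical path; any lang dir with a tokens.txt works the same way.
    tokens = kaldifst.SymbolTable.read_text("data/lang_bpe_500/tokens.txt")
    assert "#0" in tokens

    # Smallest disambig ID. The later check "labels >= first_token_disambig_id"
    # assumes #0, #1, #2, ... are the last entries of tokens.txt, so this one
    # threshold covers all of them.
    first_token_disambig_id = tokens.find("#0")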
@@ -93,22 +90,34 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
         connect=True,
     )

-    logging.info("Determinize star LG")
-    kaldifst.determinize_star(LG)
+    if True:
+        logging.info("Determinize star LG")
+        kaldifst.determinize_star(LG)

-    logging.info("minimizeencoded")
-    kaldifst.minimize_encoded(LG)
+        logging.info("minimizeencoded")
+        kaldifst.minimize_encoded(LG)
+    else:
+        # You can use this branch to compare the size of
+        # the resulting graph
+        logging.info("Determinize LG")
+        LG = kaldifst.determinize(LG)

-    # Set all disambig IDs to eps
-    for state in kaldifst.StateIterator(LG):
-        for arc in kaldifst.ArcIterator(LG, state):
-            if arc.ilabel >= token_disambig_id:
-                arc.ilabel = 0
+    LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False)

-            if arc.olabel >= word_disambig_id:
-                arc.olabel = 0
-    # reset properties as we changed the arc labels above
-    LG.properties(0xFFFFFFFF, True)
+    LG.labels[LG.labels >= first_token_disambig_id] = 0
+
+    # We only need the labels of LG during beam search decoding
+    del LG.aux_labels
+
+    LG = k2.remove_epsilon(LG)
+    logging.info(
+        f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
+    )
+
+    LG = k2.connect(LG)
+
+    logging.info("Arc sorting LG")
+    LG = k2.arc_sort(LG)

     return LG

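The hunk above replaces the per-arc kaldifst clean-up loop with a conversion to k2 followed by vectorized label surgery. A minimal standalone sketch of that post-processing, collected into one helper; the function name and signature are assumptions for illustration, while every call inside mirrors the committed code:

    import k2
    import kaldifst

    def kaldifst_to_k2_decoding_graph(
        fst: "kaldifst.StdVectorFst", first_token_disambig_id: int
    ) -> k2.Fsa:
        # Serialize the OpenFst-style graph and re-parse it as a k2 FSA.
        LG = k2.Fsa.from_openfst(fst.to_str(is_acceptor=False), acceptor=False)

        # Vectorized replacement for the old per-arc loop: turn every
        # disambiguation symbol (#0, #1, ...) into epsilon in one shot.
        LG.labels[LG.labels >= first_token_disambig_id] = 0

        # Only the input labels are needed during beam search decoding,
        # so the output labels can be dropped entirely.
        del LG.aux_labels

        LG = k2.remove_epsilon(LG)
        LG = k2.connect(LG)
        return k2.arc_sort(LG)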
@@ -116,7 +125,7 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
 def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
-    out_filename = lang_dir / "LG.fst"
+    out_filename = lang_dir / "LG.pt"

     if out_filename.is_file():
         logging.info(f"{out_filename} already exists - skipping")
@@ -126,7 +135,7 @@ def main():

     LG = compile_LG(lang_dir)
     logging.info(f"Saving LG to {out_filename}")
-    LG.write(str(out_filename))
+    torch.save(LG.as_dict(), out_filename)


 if __name__ == "__main__":
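Because the graph is now saved with torch.save as an attribute dict rather than with LG.write, downstream code would load it via k2.Fsa.from_dict, the same pattern used for L_disambig.pt earlier in this file. A hedged usage sketch with a hypothetical path:

    import k2
    import torch

    # Hypothetical path; mirrors how L_disambig.pt is loaded above.
    LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt"))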