mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
Generate k2 graph.
This commit is contained in:
parent
5ff660c84d
commit
fa843181af
@ -61,13 +61,10 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
|
||||
"""
|
||||
|
||||
tokens = kaldifst.SymbolTable.read_text(f"{lang_dir}/tokens.txt")
|
||||
words = kaldifst.SymbolTable.read_text(f"{lang_dir}/words.txt")
|
||||
|
||||
assert "#0" in tokens
|
||||
assert "#0" in words
|
||||
|
||||
token_disambig_id = tokens.find("#0")
|
||||
word_disambig_id = words.find("#0")
|
||||
first_token_disambig_id = tokens.find("#0")
|
||||
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
L = k2_to_openfst(L, olabels="aux_labels")
|
||||
@ -93,22 +90,34 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
|
||||
connect=True,
|
||||
)
|
||||
|
||||
if True:
|
||||
logging.info("Determinize star LG")
|
||||
kaldifst.determinize_star(LG)
|
||||
|
||||
logging.info("minimizeencoded")
|
||||
kaldifst.minimize_encoded(LG)
|
||||
else:
|
||||
# You can use this branch to compare the size of
|
||||
# the resulting graph
|
||||
logging.info("Determinize LG")
|
||||
LG = kaldifst.determinize(LG)
|
||||
|
||||
# Set all disambig IDs to eps
|
||||
for state in kaldifst.StateIterator(LG):
|
||||
for arc in kaldifst.ArcIterator(LG, state):
|
||||
if arc.ilabel >= token_disambig_id:
|
||||
arc.ilabel = 0
|
||||
LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False)
|
||||
|
||||
if arc.olabel >= word_disambig_id:
|
||||
arc.olabel = 0
|
||||
# reset properties as we changed the arc labels above
|
||||
LG.properties(0xFFFFFFFF, True)
|
||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
|
||||
# We only need the labels of LG during beam search decoding
|
||||
del LG.aux_labels
|
||||
|
||||
LG = k2.remove_epsilon(LG)
|
||||
logging.info(
|
||||
f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
|
||||
)
|
||||
|
||||
LG = k2.connect(LG)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
LG = k2.arc_sort(LG)
|
||||
|
||||
return LG
|
||||
|
||||
@ -116,7 +125,7 @@ def compile_LG(lang_dir: str) -> kaldifst.StdVectorFst:
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
out_filename = lang_dir / "LG.fst"
|
||||
out_filename = lang_dir / "LG.pt"
|
||||
|
||||
if out_filename.is_file():
|
||||
logging.info(f"{out_filename} already exists - skipping")
|
||||
@ -126,7 +135,7 @@ def main():
|
||||
|
||||
LG = compile_LG(lang_dir)
|
||||
logging.info(f"Saving LG to {out_filename}")
|
||||
LG.write(str(out_filename))
|
||||
torch.save(LG.as_dict(), out_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user