icefall/egs/librispeech/ASR/local/compile_hlg.py
2021-07-24 17:13:20 +08:00

83 lines
2.0 KiB
Python

#!/usr/bin/env python3
"""
This script compiles HLG from:

    - H, the CTC topology, built from the phones contained in
      data/lang/lexicon.txt
    - L, the lexicon, loaded from data/lang/L_disambig.pt.
      Caution: this lexicon contains disambiguation symbols.
    - G, the LM, loaded from data/lm/G_3_gram.fst.txt

The generated HLG is saved to data/lm/HLG.pt
"""
import k2
import torch
from icefall.lexicon import Lexicon
def _build_LG(
    L: k2.Fsa,
    G: k2.Fsa,
    first_token_disambig_id: int,
    first_word_disambig_id: int,
) -> k2.Fsa:
    """Compose L with G, determinize, and strip disambiguation symbols.

    Args:
      L: Lexicon FSA that still contains disambiguation symbols.
      G: Language-model FSA.
      first_token_disambig_id: ID of the token-side "#0" symbol. Every label
        >= this value is treated as a disambiguation symbol.
      first_word_disambig_id: ID of the word-side "#0" symbol. Every aux label
        >= this value is treated as a disambiguation symbol.

    Returns:
      An arc-sorted LG with the disambiguation symbols mapped to epsilon and
      removed, and with 0 entries dropped from its ragged aux_labels.

    Raises:
      TypeError: if LG.aux_labels is not a k2.RaggedInt after determinization.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    print("Intersecting L and G")
    LG = k2.compose(L, G)
    print(f"LG shape: {LG.shape}")

    print("Connecting LG")
    LG = k2.connect(LG)
    print(f"LG shape after k2.connect: {LG.shape}")

    print("Determinizing LG")
    LG = k2.determinize(LG)

    print("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    print("Removing disambiguation symbols on LG")
    # Relabel disambiguation symbols to 0 (epsilon) on both sides so that
    # k2.remove_epsilon below can eliminate them. This assumes every
    # disambiguation ID is >= the ID of "#0" in its symbol table.
    LG.labels[LG.labels >= first_token_disambig_id] = 0

    # An explicit check rather than `assert`: the in-place edit below depends
    # on it, and asserts are stripped when Python runs with -O.
    if not isinstance(LG.aux_labels, k2.RaggedInt):
        raise TypeError(
            "Expected LG.aux_labels to be a k2.RaggedInt after "
            f"k2.determinize, got {type(LG.aux_labels)}"
        )
    # values() is modified in place, so the change must land in aux_labels'
    # own storage (the original code relied on this too).
    aux_values = LG.aux_labels.values()
    aux_values[aux_values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    print(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    # Drop the 0 (epsilon) word IDs that remain in the ragged aux_labels.
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    print("Arc sorting LG")
    return k2.arc_sort(LG)


def main(lang_dir: str = "data/lang", lm_dir: str = "data/lm") -> None:
    """Compile HLG = H o (L o G) and save it to ``{lm_dir}/HLG.pt``.

    Args:
      lang_dir: Directory passed to ``Lexicon``. It must also contain
        ``L_disambig.pt`` (a dict accepted by ``k2.Fsa.from_dict``).
      lm_dir: Directory that contains ``G_3_gram.fst.txt`` (OpenFst text
        format). ``HLG.pt`` is written to this directory as well.
    """
    lexicon = Lexicon(lang_dir)
    max_token_id = max(lexicon.tokens)
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
    with open(f"{lm_dir}/G_3_gram.fst.txt", encoding="utf-8") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)

    LG = _build_LG(
        L,
        G,
        first_token_disambig_id=lexicon.phones["#0"],
        first_word_disambig_id=lexicon.words["#0"],
    )

    print("Composing H and LG")
    # inner_labels="phones" stores the labels matched during composition on
    # HLG under the name "phones". Other code may look that name up, so
    # confirm before renaming it.
    HLG = k2.compose(H, LG, inner_labels="phones")

    print("Connecting HLG")
    HLG = k2.connect(HLG)

    print("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)

    print(f"Saving HLG.pt to {lm_dir}")
    torch.save(HLG.as_dict(), f"{lm_dir}/HLG.pt")
if __name__ == "__main__":
    # Compile HLG only when this file is executed as a script, not when it
    # is imported.
    main()