mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
Update timit recipe

commit 4beb25c50b
parent e023a9df98
@@ -54,7 +54,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+        The language directory, e.g., data/lang_phone.
 
     Return:
       An FSA representing HLG.
@@ -63,18 +63,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-    if Path(lang_dir / "L_disambig.pt").is_file():
-        logging.info("Loading L_disambig.pt")
-        d = torch.load(Path(lang_dir/"L_disambig.pt"))
-        L = k2.Fsa.from_dict(d)
-    else:
-        logging.info("Loading L_disambig.fst.txt")
-        with open(Path(lang_dir/"L_disambig.fst.txt")) as f:
-            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
-        torch.save(L_disambig.as_dict(), Path(lang_dir / "L_disambig.pt"))
-
-    #L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
-
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
 
     if Path("data/lm/G.pt").is_file():
         logging.info("Loading pre-compiled G")
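Note: after this change compile_HLG requires lang_dir/L_disambig.pt to exist; the fallback that rebuilt it from L_disambig.fst.txt is gone. A minimal usage sketch, assuming the script's own imports (torch, k2) and the compile_HLG defined above; the HLG.pt path is illustrative only:

    import torch
    import k2

    # Compile the decoding graph for the TIMIT phone lang dir and cache it,
    # mirroring how L_disambig.pt is consumed above.
    HLG = compile_HLG("data/lang_phone")  # needs data/lang_phone/L_disambig.pt
    torch.save(HLG.as_dict(), "data/lang_phone/HLG.pt")

    # Reload later, e.g. during decoding:
    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))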
@@ -106,7 +106,7 @@ def get_tokens(lexicon: Lexicon) -> List[str]:
     ans = set()
     for _, tokens in lexicon:
         ans.update(tokens)
-    #sorted_ans = sorted(list(ans))
     sorted_ans = list(ans)
     return sorted_ans
 
@@ -276,17 +276,10 @@ def lexicon_to_fst(
     next_state = 1 # the next un-allocated state, will be incremented as we go.
     arcs = []
 
-    print('token2id ori: ', token2id)
-    print('word2id ori: ', word2id)
-
     assert token2id["<eps>"] == 0
     assert word2id["<eps>"] == 0
 
     eps = 0
-    print('token2id new: ', token2id)
-    print('word2id new: ', word2id)
-
-    print(lexicon)
     for word, tokens in lexicon:
         assert len(tokens) > 0, f"{word} has no pronunciations"
         cur_state = loop_state
@@ -306,7 +299,7 @@ def lexicon_to_fst(
         # the other one to the sil_state.
         i = len(tokens) - 1
         w = word if i == 0 else eps
-        tokens[i] = tokens[i] if i >=0 else eps
+        tokens[i] = tokens[i] if i >= 0 else eps
         arcs.append([cur_state, loop_state, tokens[i], w, score])
 
     if need_self_loops:
@@ -326,7 +319,6 @@ def lexicon_to_fst(
     arcs = [[str(i) for i in arc] for arc in arcs]
     arcs = [" ".join(arc) for arc in arcs]
     arcs = "\n".join(arcs)
-    print(arcs)
     fsa = k2.Fsa.from_str(arcs, acceptor=False)
     return fsa
 
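For reference, the string handed to k2.Fsa.from_str here is k2's text format: one arc per line as "src dst label aux_label score" (acceptor=False adds the aux_label column), with a -1 label on arcs entering the final state, whose id stands alone on the last line. A minimal sketch with illustrative ids:

    import k2

    # Two-arc transducer: token 3 maps to word 7, then a -1 arc enters the
    # final state (state 2), which is named alone on the last line.
    s = """
    0 1 3 7 0.0
    1 2 -1 -1 0.0
    2
    """
    fsa = k2.Fsa.from_str(s, acceptor=False)
    print(fsa.num_arcs)  # -> 2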
@@ -334,7 +326,6 @@ def lexicon_to_fst(
 def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
-    #out_dir = Path("data/lang_phone")
     lexicon_filename = lang_dir / "lexicon.txt"
 
     lexicon = read_lexicon(lexicon_filename)
@@ -386,7 +377,7 @@ def main():
     L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
     L_disambig.labels_sym = L.labels_sym
     L_disambig.aux_labels_sym = L.aux_labels_sym
-    L.draw(out_dir / "L.png", title="L")
+    L.draw(lang_dir / "L.png", title="L")
     L_disambig.draw(lang_dir / "L_disambig.png", title="L_disambig")
 
 
@@ -64,25 +64,22 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
-    with open(supervisions_train, 'r') as load_f:
+    with open(supervisions_train, "r") as load_f:
         load_dicts = json.load(load_f)
         for load_dict in load_dicts:
-            idx = load_dict['id']
-            text = load_dict['text']
-            phones_list = list(filter(None, text.split(' ')))
+            text = load_dict["text"]
+            phones_list = list(filter(None, text.split(" ")))
 
             for phone in phones_list:
                 if phone not in phones:
                     phones.append(phone)
 
-    with open(lexicon, 'w') as f:
+    with open(lexicon, "w") as f:
         for phone in sorted(phones):
             f.write(str(phone) + " " + str(phone))
-            f.write('\n')
+            f.write("\n")
         f.write("<UNK> <UNK>")
-        f.write('\n')
+        f.write("\n")
 
-    return lexicon
-
 
 def main():
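With this change prepare_lexicon is called purely for its side effect of writing lexicon.txt (its unused return value is dropped in the hunk below). The file is a trivial phone-to-phone mapping plus an <UNK> entry; a hedged sanity-check sketch, assuming the recipe writes to data/lang_phone/lexicon.txt:

    from pathlib import Path

    # Each line of the phone lexicon should read "<phone> <phone>",
    # e.g. "aa aa", ending with "<UNK> <UNK>".
    for line in Path("data/lang_phone/lexicon.txt").read_text().splitlines():
        word, *pron = line.split()
        assert pron == [word], line  # pronunciation is the phone itself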
@@ -90,9 +87,8 @@ def main():
     manifests_dir = Path(args.manifests_dir)
     lang_dir = Path(args.lang_dir)
 
-    logging.info(f"Generating lexicon.txt and train.text")
-    lexicon_file = prepare_lexicon(manifests_dir, lang_dir)
+    logging.info("Generating lexicon.txt")
+    prepare_lexicon(manifests_dir, lang_dir)
 
 
 if __name__ == "__main__":
@@ -103,4 +99,3 @@ if __name__ == "__main__":
     logging.basicConfig(format=formatter, level=logging.INFO)
-
     main()
 
@@ -311,26 +311,20 @@ class TimitAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(
-            self.args.feature_dir / "cuts_TRAIN.json.gz"
-        )
+        cuts_train = load_manifest(self.args.feature_dir / "cuts_TRAIN.json.gz")
 
         return cuts_train
 
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(
-            self.args.feature_dir / "cuts_DEV.json.gz"
-        )
+        cuts_valid = load_manifest(self.args.feature_dir / "cuts_DEV.json.gz")
 
         return cuts_valid
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.debug("About to get test cuts")
-        cuts_test = load_manifest(
-            self.args.feature_dir / "cuts_TEST.json.gz"
-        )
+        cuts_test = load_manifest(self.args.feature_dir / "cuts_TEST.json.gz")
 
         return cuts_test
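The three accessors above are thin wrappers over lhotse manifests; an equivalent standalone sketch, assuming lhotse is installed and the recipe's feature dir is data/fbank:

    from lhotse import load_manifest

    # The same one-line call the datamodule now makes, outside the class.
    cuts_train = load_manifest("data/fbank/cuts_TRAIN.json.gz")
    print(len(cuts_train), "training cuts")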
@@ -449,7 +449,6 @@ def main():
     )
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        #load_checkpoint(f"tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt", model)
     else:
         start = params.epoch - params.avg + 1
         filenames = []
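For context, the else branch goes on to average the last params.avg checkpoints. A sketch of that pattern, assuming icefall's average_checkpoints helper and the start/filenames variables from the context lines above:

    from icefall.checkpoint import average_checkpoints

    # start = params.epoch - params.avg + 1, as above.
    filenames = [
        f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)
    ]
    model.load_state_dict(average_checkpoints(filenames))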
@@ -470,18 +469,6 @@ def main():
     model.eval()
 
     timit = TimitAsrDataModule(args)
-    # CAUTION: `test_sets` is for displaying only.
-    # If you want to skip test-clean, you have to skip
-    # it inside the for loop. That is, use
-    #
-    # if test_set == 'test-clean': continue
-    #
-    #test_sets = ["test-clean", "test-other"]
-    #test_sets = ["test-other"]
-    #for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-    #if test_set == "test-clean": continue
-    #if test_set == "test-other": break
-    test_set = "TEST"
     test_dl = timit.test_dataloaders()
     results_dict = decode_dataset(
         dl=test_dl,
@@ -491,7 +478,7 @@ def main():
         lexicon=lexicon,
         G=G,
     )
-
+    test_set = "TEST"
     save_results(
         params=params, test_set_name=test_set, results_dict=results_dict
     )