Update TIMIT recipe

Mingshuang Luo 2021-10-28 14:45:24 +08:00
parent e023a9df98
commit 4beb25c50b
5 changed files with 25 additions and 69 deletions

View File

@@ -54,7 +54,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+        The language directory, e.g., data/lang_phone.
     Return:
       An FSA representing HLG.
@@ -63,18 +63,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-    if Path(lang_dir / "L_disambig.pt").is_file():
-        logging.info("Loading L_disambig.pt")
-        d = torch.load(Path(lang_dir / "L_disambig.pt"))
-        L = k2.Fsa.from_dict(d)
-    else:
-        logging.info("Loading L_disambig.fst.txt")
-        with open(Path(lang_dir / "L_disambig.fst.txt")) as f:
-            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
-        torch.save(L_disambig.as_dict(), Path(lang_dir / "L_disambig.pt"))
-    #L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
     if Path("data/lm/G.pt").is_file():
         logging.info("Loading pre-compiled G")

View File

@@ -106,7 +106,7 @@ def get_tokens(lexicon: Lexicon) -> List[str]:
     ans = set()
     for _, tokens in lexicon:
         ans.update(tokens)
-    #sorted_ans = sorted(list(ans))
     sorted_ans = list(ans)
     return sorted_ans
@@ -276,17 +276,10 @@ def lexicon_to_fst(
     next_state = 1  # the next un-allocated state, will be incremented as we go.
     arcs = []
-    print('token2id ori: ', token2id)
-    print('word2id ori: ', word2id)
     assert token2id["<eps>"] == 0
     assert word2id["<eps>"] == 0
     eps = 0
-    print('token2id new: ', token2id)
-    print('word2id new: ', word2id)
-    print(lexicon)
     for word, tokens in lexicon:
         assert len(tokens) > 0, f"{word} has no pronunciations"
         cur_state = loop_state
@@ -306,7 +299,7 @@ def lexicon_to_fst(
         # the other one to the sil_state.
         i = len(tokens) - 1
         w = word if i == 0 else eps
-        tokens[i] = tokens[i] if i >=0 else eps
+        tokens[i] = tokens[i] if i >= 0 else eps
         arcs.append([cur_state, loop_state, tokens[i], w, score])

     if need_self_loops:
@@ -326,7 +319,6 @@ def lexicon_to_fst(
     arcs = [[str(i) for i in arc] for arc in arcs]
     arcs = [" ".join(arc) for arc in arcs]
     arcs = "\n".join(arcs)
-    print(arcs)

     fsa = k2.Fsa.from_str(arcs, acceptor=False)
     return fsa
@@ -334,7 +326,6 @@ def lexicon_to_fst(
 def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
-    #out_dir = Path("data/lang_phone")
     lexicon_filename = lang_dir / "lexicon.txt"

     lexicon = read_lexicon(lexicon_filename)
@@ -386,7 +377,7 @@ def main():
     L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
     L_disambig.labels_sym = L.labels_sym
     L_disambig.aux_labels_sym = L.aux_labels_sym
-    L.draw(out_dir / "L.png", title="L")
+    L.draw(lang_dir / "L.png", title="L")
     L_disambig.draw(lang_dir / "L_disambig.png", title="L_disambig")
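
For reference, lexicon_to_fst serializes each arc as "src dst token_id word_id score" and hands the joined string to k2.Fsa.from_str, as the hunks above show. A toy example of that text format (the IDs are made up for illustration; the arc entering the final state carries label -1 by k2 convention):

    import k2

    arcs = "\n".join(
        [
            "0 1 2 3 0.0",    # token id 2, word id 3
            "1 0 4 0 0.0",    # token id 4, word id 0 (<eps>)
            "0 2 -1 -1 0.0",  # arc into the final state uses label -1
            "2",              # the final state on a line of its own
        ]
    )
    L = k2.Fsa.from_str(arcs, acceptor=False)
    print(L)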

View File

@@ -64,25 +64,22 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     lexicon = Path(lang_dir) / "lexicon.txt"

     logging.info(f"Loading {supervisions_train}!")
-    with open(supervisions_train, 'r') as load_f:
+    with open(supervisions_train, "r") as load_f:
         load_dicts = json.load(load_f)
         for load_dict in load_dicts:
-            idx = load_dict['id']
-            text = load_dict['text']
-            phones_list = list(filter(None, text.split(' ')))
+            text = load_dict["text"]
+            phones_list = list(filter(None, text.split(" ")))
             for phone in phones_list:
                 if phone not in phones:
                     phones.append(phone)

-    with open(lexicon, 'w') as f:
+    with open(lexicon, "w") as f:
         for phone in sorted(phones):
             f.write(str(phone) + " " + str(phone))
-            f.write('\n')
+            f.write("\n")
         f.write("<UNK> <UNK>")
-        f.write('\n')
-    return lexicon
+        f.write("\n")

 def main():
@@ -90,9 +87,8 @@ def main():
     manifests_dir = Path(args.manifests_dir)
     lang_dir = Path(args.lang_dir)

-    logging.info(f"Generating lexicon.txt and train.text")
-    lexicon_file = prepare_lexicon(manifests_dir, lang_dir)
+    logging.info("Generating lexicon.txt")
+    prepare_lexicon(manifests_dir, lang_dir)

 if __name__ == "__main__":
@@ -103,4 +99,3 @@ if __name__ == "__main__":
     logging.basicConfig(format=formatter, level=logging.INFO)

     main()
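
Since TIMIT is transcribed at the phone level, each "word" in the generated lexicon is a phone mapping to itself, followed by an <UNK> entry. Illustrative output (the actual phone inventory comes from the training supervisions):

    aa aa
    ae ae
    ah ah
    ...
    sil sil
    <UNK> <UNK>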

View File

@@ -311,26 +311,20 @@ class TimitAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(
-            self.args.feature_dir / "cuts_TRAIN.json.gz"
-        )
+        cuts_train = load_manifest(self.args.feature_dir / "cuts_TRAIN.json.gz")
         return cuts_train

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(
-            self.args.feature_dir / "cuts_DEV.json.gz"
-        )
+        cuts_valid = load_manifest(self.args.feature_dir / "cuts_DEV.json.gz")
         return cuts_valid

     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.debug("About to get test cuts")
-        cuts_test = load_manifest(
-            self.args.feature_dir / "cuts_TEST.json.gz"
-        )
+        cuts_test = load_manifest(self.args.feature_dir / "cuts_TEST.json.gz")
         return cuts_test
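
load_manifest reads the gzipped JSON manifest into a lhotse CutSet, and @lru_cache() ensures each accessor parses it at most once per run. A standalone sketch of what these accessors return (the path is hypothetical):

    from lhotse import CutSet, load_manifest

    cuts: CutSet = load_manifest("data/fbank/cuts_TRAIN.json.gz")
    print(len(cuts))         # number of training cuts
    print(next(iter(cuts)))  # first cut, with its features and supervision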

View File

@@ -449,7 +449,6 @@ def main():
     )
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        #load_checkpoint(f"tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt", model)
     else:
         start = params.epoch - params.avg + 1
         filenames = []
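
For params.avg > 1, the icefall recipes typically average the last params.avg epoch checkpoints before decoding. A sketch of how the else branch above usually continues (not shown in this hunk; average_checkpoints from icefall.checkpoint is assumed, and params/model come from the surrounding main()):

    from icefall.checkpoint import average_checkpoints

    start = params.epoch - params.avg + 1
    filenames = [
        f"{params.exp_dir}/epoch-{i}.pt"
        for i in range(start, params.epoch + 1)
        if i >= 0
    ]
    logging.info(f"averaging {filenames}")
    model.load_state_dict(average_checkpoints(filenames))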
@@ -470,18 +469,6 @@ def main():
     model.eval()

     timit = TimitAsrDataModule(args)
-    # CAUTION: `test_sets` is for displaying only.
-    # If you want to skip test-clean, you have to skip
-    # it inside the for loop. That is, use
-    #
-    #   if test_set == 'test-clean': continue
-    #
-    #test_sets = ["test-clean", "test-other"]
-    #test_sets = ["test-other"]
-    #for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-    #if test_set == "test-clean": continue
-    #if test_set == "test-other": break
-    test_set = "TEST"
     test_dl = timit.test_dataloaders()
     results_dict = decode_dataset(
         dl=test_dl,
@@ -491,7 +478,7 @@ def main():
         lexicon=lexicon,
         G=G,
     )
-
+    test_set = "TEST"
     save_results(
         params=params, test_set_name=test_set, results_dict=results_dict
     )