Update timit recipe

Mingshuang Luo 2021-10-28 14:45:24 +08:00
parent e023a9df98
commit 4beb25c50b
5 changed files with 25 additions and 69 deletions

View File

@@ -54,7 +54,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+        The language directory, e.g., data/lang_phone.
     Return:
       An FSA representing HLG.
@@ -63,18 +63,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-    if Path(lang_dir / "L_disambig.pt").is_file():
-        logging.info("Loading L_disambig.pt")
-        d = torch.load(Path(lang_dir/"L_disambig.pt"))
-        L = k2.Fsa.from_dict(d)
-    else:
-        logging.info("Loading L_disambig.fst.txt")
-        with open(Path(lang_dir/"L_disambig.fst.txt")) as f:
-            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(L_disambig.as_dict(), Path(lang_dir / "L_disambig.pt"))
-    #L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
     if Path("data/lm/G.pt").is_file():
         logging.info("Loading pre-compiled G")

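The deleted fallback block above also carried a latent bug: on the else path it saved `L_disambig.as_dict()`, but only `L` is defined in that scope. For reference, a minimal sketch of the cache-or-parse pattern that block was attempting, with the name fixed (a hypothetical helper, not part of the recipe):

from pathlib import Path

import k2
import torch


def load_L_disambig(lang_dir: Path) -> k2.Fsa:
    # Prefer the cached torch dict if a previous run saved it.
    pt_path = lang_dir / "L_disambig.pt"
    if pt_path.is_file():
        return k2.Fsa.from_dict(torch.load(pt_path))
    # Otherwise parse the OpenFst text form and cache it for next time.
    with open(lang_dir / "L_disambig.fst.txt") as f:
        L = k2.Fsa.from_openfst(f.read(), acceptor=False)
    torch.save(L.as_dict(), pt_path)
    return L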
View File

@@ -106,7 +106,7 @@ def get_tokens(lexicon: Lexicon) -> List[str]:
     ans = set()
     for _, tokens in lexicon:
         ans.update(tokens)
-    #sorted_ans = sorted(list(ans))
     sorted_ans = list(ans)
     return sorted_ans
@@ -276,17 +276,10 @@ def lexicon_to_fst(
     next_state = 1  # the next un-allocated state, will be incremented as we go.
     arcs = []
-    print('token2id ori: ', token2id)
-    print('word2id ori: ', word2id)
     assert token2id["<eps>"] == 0
     assert word2id["<eps>"] == 0
     eps = 0
-    print('token2id new: ', token2id)
-    print('word2id new: ', word2id)
-    print(lexicon)
     for word, tokens in lexicon:
         assert len(tokens) > 0, f"{word} has no pronunciations"
         cur_state = loop_state
@@ -326,7 +319,6 @@ def lexicon_to_fst(
     arcs = [[str(i) for i in arc] for arc in arcs]
     arcs = [" ".join(arc) for arc in arcs]
     arcs = "\n".join(arcs)
-    print(arcs)
     fsa = k2.Fsa.from_str(arcs, acceptor=False)
     return fsa
@@ -334,7 +326,6 @@ def lexicon_to_fst(
 def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
-    #out_dir = Path("data/lang_phone")
     lexicon_filename = lang_dir / "lexicon.txt"
     lexicon = read_lexicon(lexicon_filename)
@@ -386,7 +377,7 @@ def main():
         L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
         L_disambig.labels_sym = L.labels_sym
         L_disambig.aux_labels_sym = L.aux_labels_sym
-        L.draw(out_dir / "L.png", title="L")
+        L.draw(lang_dir / "L.png", title="L")
         L_disambig.draw(lang_dir / "L_disambig.png", title="L_disambig")

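For context on the `k2.Fsa.from_str(arcs, acceptor=False)` call whose input the removed `print(arcs)` was inspecting: `lexicon_to_fst` serializes arcs as lines of `src dst token_id word_id score`, plus a line naming the final state. A tiny self-contained sketch of that format (toy symbol ids, not TIMIT's real tables):

import k2

# Toy transducer: one arc emitting token 1 / word 1, then a final arc.
# k2 marks arcs entering the final state with label -1 (aux_label -1 too).
arcs = [
    [0, 0, 1, 1, 0],    # src dst token_id word_id score
    [0, 1, -1, -1, 0],  # transition to the final state
    [1],                # the final state itself
]
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)

fsa = k2.Fsa.from_str(arcs, acceptor=False)
print(fsa)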
View File

@@ -64,25 +64,22 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     lexicon = Path(lang_dir) / "lexicon.txt"
     logging.info(f"Loading {supervisions_train}!")
-    with open(supervisions_train, 'r') as load_f:
+    with open(supervisions_train, "r") as load_f:
         load_dicts = json.load(load_f)
         for load_dict in load_dicts:
-            idx = load_dict['id']
-            text = load_dict['text']
-            phones_list = list(filter(None, text.split(' ')))
+            text = load_dict["text"]
+            phones_list = list(filter(None, text.split(" ")))
             for phone in phones_list:
                 if phone not in phones:
                     phones.append(phone)
-    with open(lexicon, 'w') as f:
+    with open(lexicon, "w") as f:
         for phone in sorted(phones):
             f.write(str(phone) + " " + str(phone))
-            f.write('\n')
+            f.write("\n")
         f.write("<UNK> <UNK>")
-        f.write('\n')
-    return lexicon
+        f.write("\n")
 def main():
@@ -90,9 +87,8 @@ def main():
     manifests_dir = Path(args.manifests_dir)
     lang_dir = Path(args.lang_dir)
-    logging.info(f"Generating lexicon.txt and train.text")
-    lexicon_file = prepare_lexicon(manifests_dir, lang_dir)
+    logging.info("Generating lexicon.txt")
+    prepare_lexicon(manifests_dir, lang_dir)
 if __name__ == "__main__":
@@ -103,4 +99,3 @@ if __name__ == "__main__":
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()

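The net effect of this file's changes: double quotes throughout, and the unused `idx` and return value dropped. The lexicon it writes maps every phone to itself, plus an `<UNK>` entry. A minimal sketch of the resulting format (toy phone set; TIMIT's comes from the supervisions):

phones = ["sil", "aa", "b"]  # toy phone set for illustration

with open("lexicon.txt", "w") as f:
    for phone in sorted(phones):
        f.write(phone + " " + phone)
        f.write("\n")
    f.write("<UNK> <UNK>")
    f.write("\n")

# lexicon.txt now contains:
#   aa aa
#   b b
#   sil sil
#   <UNK> <UNK>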
View File

@@ -311,26 +311,20 @@ class TimitAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(
-            self.args.feature_dir / "cuts_TRAIN.json.gz"
-        )
+        cuts_train = load_manifest(self.args.feature_dir / "cuts_TRAIN.json.gz")
         return cuts_train
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(
-            self.args.feature_dir / "cuts_DEV.json.gz"
-        )
+        cuts_valid = load_manifest(self.args.feature_dir / "cuts_DEV.json.gz")
         return cuts_valid
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.debug("About to get test cuts")
-        cuts_test = load_manifest(
-            self.args.feature_dir / "cuts_TEST.json.gz"
-        )
+        cuts_test = load_manifest(self.args.feature_dir / "cuts_TEST.json.gz")
         return cuts_test

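These hunks only reflow the `load_manifest` calls onto one line; behavior is unchanged. The `@lru_cache()` decorator is what makes the getters cheap to call repeatedly: each manifest is parsed once per process, and later calls return the same CutSet. A minimal sketch of the pattern (assuming lhotse's `load_manifest`; the class and path are hypothetical):

from functools import lru_cache
from pathlib import Path

from lhotse import CutSet, load_manifest


class CutProvider:
    def __init__(self, feature_dir: Path):
        self.feature_dir = feature_dir

    @lru_cache()
    def test_cuts(self) -> CutSet:
        # Parsed from disk only on the first call; cached afterwards.
        return load_manifest(self.feature_dir / "cuts_TEST.json.gz")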
View File

@@ -449,7 +449,6 @@ def main():
     )
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        #load_checkpoint(f"tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt", model)
     else:
         start = params.epoch - params.avg + 1
         filenames = []
@@ -470,18 +469,6 @@ def main():
     model.eval()
     timit = TimitAsrDataModule(args)
-    # CAUTION: `test_sets` is for displaying only.
-    # If you want to skip test-clean, you have to skip
-    # it inside the for loop. That is, use
-    #
-    #   if test_set == 'test-clean': continue
-    #
-    #test_sets = ["test-clean", "test-other"]
-    #test_sets = ["test-other"]
-    #for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-    #if test_set == "test-clean": continue
-    #if test_set == "test-other": break
-    test_set = "TEST"
     test_dl = timit.test_dataloaders()
     results_dict = decode_dataset(
         dl=test_dl,
@@ -491,7 +478,7 @@ def main():
         lexicon=lexicon,
         G=G,
     )
+    test_set = "TEST"
     save_results(
         params=params, test_set_name=test_set, results_dict=results_dict
     )
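
On the averaging branch left intact above: `start = params.epoch - params.avg + 1` selects the last `avg` epoch checkpoints ending at `epoch`. A worked sketch of the filenames it collects (icefall then feeds such a list to `average_checkpoints`; the values here are hypothetical):

epoch, avg = 20, 5
exp_dir = "exp"  # hypothetical experiment directory

start = epoch - avg + 1  # 16
filenames = [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]
print(filenames)
# ['exp/epoch-16.pt', 'exp/epoch-17.pt', 'exp/epoch-18.pt',
#  'exp/epoch-19.pt', 'exp/epoch-20.pt']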