From 4beb25c50b0ccda8a3c4851e6ac994caddeccad6 Mon Sep 17 00:00:00 2001
From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Date: Thu, 28 Oct 2021 14:45:24 +0800
Subject: [PATCH] Update timit recipe

---
 egs/timit/ASR/local/compile_hlg.py            | 15 ++--------
 egs/timit/ASR/local/prepare_lang.py           | 17 +++--------
 egs/timit/ASR/local/prepare_lexicon.py        | 29 ++++++++-----------
 egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 16 ++++------
 egs/timit/ASR/tdnn_lstm_ctc/decode.py         | 17 ++---------
 5 files changed, 25 insertions(+), 69 deletions(-)

diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index ad8b41de1..58cab4cf2 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -54,7 +54,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+        The language directory, e.g., data/lang_phone.
 
     Return:
       An FSA representing HLG.
@@ -63,18 +63,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-
-    if Path(lang_dir / "L_disambig.pt").is_file():
-        logging.info("Loading L_disambig.pt")
-        d = torch.load(Path(lang_dir/"L_disambig.pt"))
-        L = k2.Fsa.from_dict(d)
-    else:
-        logging.info("Loading L_disambig.fst.txt")
-        with open(Path(lang_dir/"L_disambig.fst.txt")) as f:
-            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(L_disambig.as_dict(), Path(lang_dir / "L_disambig.pt"))
-
-    #L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
 
     if Path("data/lm/G.pt").is_file():
         logging.info("Loading pre-compiled G")
diff --git a/egs/timit/ASR/local/prepare_lang.py b/egs/timit/ASR/local/prepare_lang.py
index 80ba015cb..5dca77146 100644
--- a/egs/timit/ASR/local/prepare_lang.py
+++ b/egs/timit/ASR/local/prepare_lang.py
@@ -106,7 +106,7 @@ def get_tokens(lexicon: Lexicon) -> List[str]:
     ans = set()
     for _, tokens in lexicon:
         ans.update(tokens)
-    #sorted_ans = sorted(list(ans))
+    sorted_ans = list(ans)
     return sorted_ans
 
 
@@ -275,18 +275,11 @@ def lexicon_to_fst(
     loop_state = 0  # words enter and leave from here
     next_state = 1  # the next un-allocated state, will be incremented as we go.
     arcs = []
-
-    print('token2id ori: ', token2id)
-    print('word2id ori: ', word2id)
 
     assert token2id["<eps>"] == 0
     assert word2id["<eps>"] == 0
 
     eps = 0
-    print('token2id new: ', token2id)
-    print('word2id new: ', word2id)
-
-    print(lexicon)
     for word, tokens in lexicon:
         assert len(tokens) > 0, f"{word} has no pronunciations"
         cur_state = loop_state
@@ -306,7 +299,7 @@
         # the other one to the sil_state.
         i = len(tokens) - 1
         w = word if i == 0 else eps
-        tokens[i] = tokens[i] if i >=0 else eps
+        tokens[i] = tokens[i] if i >= 0 else eps
         arcs.append([cur_state, loop_state, tokens[i], w, score])
 
     if need_self_loops:
@@ -326,7 +319,6 @@ def lexicon_to_fst(
     arcs = [[str(i) for i in arc] for arc in arcs]
     arcs = [" ".join(arc) for arc in arcs]
     arcs = "\n".join(arcs)
-    print(arcs)
 
     fsa = k2.Fsa.from_str(arcs, acceptor=False)
     return fsa
@@ -334,9 +326,8 @@ def lexicon_to_fst(
 def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)
-    #out_dir = Path("data/lang_phone")
     lexicon_filename = lang_dir / "lexicon.txt"
-
+
     lexicon = read_lexicon(lexicon_filename)
 
     tokens = get_tokens(lexicon)
@@ -386,7 +377,7 @@ def main():
         L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
         L_disambig.labels_sym = L.labels_sym
         L_disambig.aux_labels_sym = L.aux_labels_sym
-        L.draw(out_dir / "L.png", title="L")
+        L.draw(lang_dir / "L.png", title="L")
         L_disambig.draw(lang_dir / "L_disambig.png", title="L_disambig")
 
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 65c1dca44..2f5afb497 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -59,48 +59,43 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
       The lexicon.txt file and the train.text in lang_dir.
     """
     phones = []
-
+
     supervisions_train = Path(manifests_dir) / "supervisions_TRAIN.json"
     lexicon = Path(lang_dir) / "lexicon.txt"
-
+
     logging.info(f"Loading {supervisions_train}!")
-    with open(supervisions_train, 'r') as load_f:
+    with open(supervisions_train, "r") as load_f:
         load_dicts = json.load(load_f)
         for load_dict in load_dicts:
-            idx = load_dict['id']
-            text = load_dict['text']
-            phones_list = list(filter(None, text.split(' ')))
+            text = load_dict["text"]
+            phones_list = list(filter(None, text.split(" ")))
             for phone in phones_list:
                 if phone not in phones:
                     phones.append(phone)
-
-    with open(lexicon, 'w') as f:
+
+    with open(lexicon, "w") as f:
         for phone in sorted(phones):
             f.write(str(phone) + " " + str(phone))
-            f.write('\n')
+            f.write("\n")
         f.write(" ")
-        f.write('\n')
+        f.write("\n")
 
-    return lexicon
-
 
 def main():
     args = get_args()
     manifests_dir = Path(args.manifests_dir)
     lang_dir = Path(args.lang_dir)
 
-    logging.info(f"Generating lexicon.txt and train.text")
+    logging.info("Generating lexicon.txt")
+    prepare_lexicon(manifests_dir, lang_dir)
 
-    lexicon_file = prepare_lexicon(manifests_dir, lang_dir)
-
 
 
 if __name__ == "__main__":
     formatter = (
         "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     main()
-
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 078a4be89..02831469b 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -311,26 +311,20 @@ class TimitAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(
-            self.args.feature_dir / "cuts_TRAIN.json.gz"
-        )
+        cuts_train = load_manifest(self.args.feature_dir / "cuts_TRAIN.json.gz")
         return cuts_train
 
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(
-            self.args.feature_dir / "cuts_DEV.json.gz"
-        )
+        cuts_valid = load_manifest(self.args.feature_dir / "cuts_DEV.json.gz")
         return cuts_valid
 
     @lru_cache()
-    def test_cuts(self) -> CutSet:
+    def test_cuts(self) -> CutSet:
         logging.debug("About to get test cuts")
-        cuts_test = load_manifest(
-            self.args.feature_dir / "cuts_TEST.json.gz"
-        )
-
+        cuts_test = load_manifest(self.args.feature_dir / "cuts_TEST.json.gz")
+
         return cuts_test
 
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 0f90a6e9f..9b33f567b 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -310,7 +310,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-
+
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
@@ -449,7 +449,6 @@ def main():
     )
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        #load_checkpoint(f"tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt", model)
     else:
         start = params.epoch - params.avg + 1
         filenames = []
@@ -470,18 +469,6 @@ def main():
     model.eval()
 
     timit = TimitAsrDataModule(args)
-    # CAUTION: `test_sets` is for displaying only.
-    # If you want to skip test-clean, you have to skip
-    # it inside the for loop. That is, use
-    #
-    #   if test_set == 'test-clean': continue
-    #
-    #test_sets = ["test-clean", "test-other"]
-    #test_sets = ["test-other"]
-    #for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-        #if test_set == "test-clean": continue
-        #if test_set == "test-other": break
-    test_set = "TEST"
     test_dl = timit.test_dataloaders()
     results_dict = decode_dataset(
         dl=test_dl,
@@ -491,7 +478,7 @@ def main():
         lexicon=lexicon,
         G=G,
     )
-
+    test_set = "TEST"
    save_results(
        params=params, test_set_name=test_set, results_dict=results_dict
    )
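
Note on the compile_hlg.py hunk: the patch drops the branch that tried to fall back to L_disambig.fst.txt and instead loads the serialized FSA directly. The sketch below only illustrates that loading pattern; it is not part of the patch, and the lang_dir path and the hard-coded max_token_id are hypothetical placeholders (the real script takes max_token_id from its Lexicon object).

# Minimal sketch, assuming prepare_lang.py has already written L_disambig.pt.
# lang_dir and max_token_id below are illustrative placeholders.
import logging
from pathlib import Path

import k2
import torch

lang_dir = Path("data/lang_phone")  # assumed output of local/prepare_lang.py
max_token_id = 47  # placeholder; compile_HLG() uses max(lexicon.tokens)

# Build the CTC topology H over all token IDs, as compile_HLG() does.
H = k2.ctc_topo(max_token_id)

# Load the serialized L_disambig FSA directly instead of rebuilding it
# from L_disambig.fst.txt, which is what the patched compile_hlg.py does.
L = k2.Fsa.from_dict(torch.load(lang_dir / "L_disambig.pt"))

logging.info("Loaded H and L_disambig; ready to compose HLG")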