diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 72f39ef40..a1c6de8b9 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -333,7 +333,7 @@ def main():
     logging.info(f"device: {device}")
 
     HLG = k2.Fsa.from_dict(
-        torch.load("data/lang_phone/HLG.pt", map_location="cpu")
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
     )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 8794fcc6d..a3df0a632 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -61,20 +61,11 @@ def get_params() -> AttributeDict:
             "lang_dir": Path("data/lang_phone"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 23,
-            "subsampling_factor": 1,
             "search_beam": 20,
-            "output_beam": 5,
+            "output_beam": 8,
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
-            # Possible values for method:
-            #  - 1best
-            #  - nbest
-            #  - nbest-rescoring
-            #  - whole-lattice-rescoring
-            "method": "1best",
-            # num_paths is used when method is "nbest" and "nbest-rescoring"
-            "num_paths": 30,
         }
     )
     return params
@@ -85,29 +76,17 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     HLG: k2.Fsa,
     batch: dict,
-    lexicon: Lexicon,
-    G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[int]]]:
-    """Decode one batch and return the result in a dict. The dict has the
-    following format:
+    word_table: k2.SymbolTable,
+) -> List[List[int]]:
+    """Decode one batch and return the result as a list of lists.
+    Each sublist contains the word IDs for one utterance in the batch.
 
-    - key: It indicates the setting used for decoding. For example,
-           if no rescoring is used, the key is the string `no_rescore`.
-           If LM rescoring is used, the key is the string `lm_scale_xxx`,
-           where `xxx` is the value of `lm_scale`. An example key is
-           `lm_scale_0.7`
-    - value: It contains the decoding result. `len(value)` equals to
-             batch size. `value[i]` is the decoding result for the i-th
-             utterance in the given batch.
     Args:
       params:
         It's the return value of :func:`get_params`.
 
-        - params.method is "1best", it uses 1best decoding without LM rescoring.
-        - params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
-        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
-          rescoring.
+        Note: `params` no longer has a `method` field; this recipe
+        always uses 1best decoding.
       model:
         The neural model.
@@ -117,15 +96,11 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      lexicon:
-        It contains word symbol table.
-      G:
-        An LM. It is not None when params.method is "nbest-rescoring"
-        or "whole-lattice-rescoring". In general, the G in HLG
-        is a 3-gram LM, while this G is a 4-gram LM.
+        (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
+      word_table:
+        It is the word symbol table.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return the decoding result. `len(ans)` equals the batch size.
""" device = HLG.device feature = batch["inputs"] @@ -141,8 +116,8 @@ def decode_one_batch( supervision_segments = torch.stack( ( supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, + supervisions["start_frame"], + supervisions["num_frames"], ), 1, ).to(torch.int32) @@ -157,46 +132,12 @@ def decode_one_batch( max_active_states=params.max_active_states, ) - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - ) - key = f"no_rescore-{params.num_paths}" - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] - - lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - ) - else: - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list - ) - - ans = dict() - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - return ans + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return hyps def decode_dataset( @@ -204,9 +145,8 @@ def decode_dataset( params: AttributeDict, model: nn.Module, HLG: k2.Fsa, - lexicon: Lexicon, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: + word_table: k2.SymbolTable, +) -> List[Tuple[List[int], List[int]]]: """Decode dataset. Args: @@ -218,16 +158,10 @@ def decode_dataset( The neural model. HLG: The decoding graph. - lexicon: - It contains word symbol table. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. + word_table: + It is word symbol table. Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: + Return a tuple contains two elements (ref_text, hyp_text): The first is the reference transcript, and the second is the predicted result. """ @@ -240,27 +174,25 @@ def decode_dataset( except TypeError: num_batches = "?" 
-    results = defaultdict(list)
+    results = []
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
 
-        hyps_dict = decode_one_batch(
+        hyps = decode_one_batch(
             params=params,
             model=model,
             HLG=HLG,
             batch=batch,
-            lexicon=lexicon,
-            G=G,
+            word_table=word_table,
         )
 
-        for lm_scale, hyps in hyps_dict.items():
-            this_batch = []
-            assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+        this_batch = []
+        assert len(hyps) == len(texts)
+        for hyp_words, ref_text in zip(hyps, texts):
+            ref_words = ref_text.split()
+            this_batch.append((ref_words, hyp_words))
 
-            results[lm_scale].extend(this_batch)
+        results.extend(this_batch)
 
         num_cuts += len(batch["supervisions"]["text"])
 
@@ -274,38 +206,46 @@
 
 
 def save_results(
-    params: AttributeDict,
+    exp_dir: Path,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
-):
-    test_set_wers = dict()
-    for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
-        store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
+    results: List[Tuple[List[int], List[int]]],
+) -> None:
+    """Save results to `exp_dir`.
+    Args:
+      exp_dir:
+        The output directory. This function creates the following files inside
+        this directory:
 
-        # The following prints out WERs, per-word error statistics and aligned
-        # ref/hyp pairs.
-        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
-        with open(errs_filename, "w") as f:
-            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
-            test_set_wers[key] = wer
+          - recogs-{test_set_name}.txt
 
-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            It contains the reference and hypothesis results, like below::
 
-    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
+                ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+                hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+                ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
+                hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
 
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
-    for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
-        note = ""
-    logging.info(s)
+          - errs-{test_set_name}.txt
+
+            It contains the detailed WER.
+      test_set_name:
+        The name of the test set, which will be part of the result filename.
+      results:
+        A list of tuples, each of which contains (ref_words, hyp_words).
+    Returns:
+      Return None.
+    """
+    recog_path = exp_dir / f"recogs-{test_set_name}.txt"
+    store_transcripts(filename=recog_path, texts=results)
+    logging.info(f"The transcripts are stored in {recog_path}")
+
+    # The following prints out WERs, per-word error statistics and aligned
+    # ref/hyp pairs.
+    errs_filename = exp_dir / f"errs-{test_set_name}.txt"
+    with open(errs_filename, "w") as f:
+        wer = write_error_stats(f, f"{test_set_name}", results)
+
+    logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
 
 @torch.no_grad()
@@ -322,7 +262,7 @@ def main():
     logging.info(params)
 
     lexicon = Lexicon(params.lang_dir)
-    max_phone_id = max(lexicon.tokens)
+    max_token_id = max(lexicon.tokens)
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
@@ -331,53 +271,14 @@
     logging.info(f"device: {device}")
 
     HLG = k2.Fsa.from_dict(
-        torch.load("data/lang_phone/HLG.pt", map_location="cpu")
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
     )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
-    if not hasattr(HLG, "lm_scores"):
-        HLG.lm_scores = HLG.scores.clone()
-
-    if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
-        if not (params.lm_dir / "G_4_gram.pt").is_file():
-            logging.info("Loading G_4_gram.fst.txt")
-            logging.warning("It may take 8 minutes.")
-            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-                first_word_disambig_id = lexicon.word_table["#0"]
-
-                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-                # G.aux_labels is not needed in later computations, so
-                # remove it here.
-                del G.aux_labels
-                # CAUTION: The following line is crucial.
-                # Arcs entering the back-off state have label equal to #0.
-                # We have to change it to 0 here.
-                G.labels[G.labels >= first_word_disambig_id] = 0
-                G = k2.Fsa.from_fsas([G]).to(device)
-                G = k2.arc_sort(G)
-                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
-        else:
-            logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
-            G = k2.Fsa.from_dict(d).to(device)
-
-        if params.method == "whole-lattice-rescoring":
-            # Add epsilon self-loops to G as we will compose
-            # it with the whole lattice later
-            G = k2.add_epsilon_self_loops(G)
-            G = k2.arc_sort(G)
-            G = G.to(device)
-
-        # G.lm_scores is used to replace HLG.lm_scores during
-        # LM rescoring.
-        G.lm_scores = G.scores.clone()
-    else:
-        G = None
-
     model = Tdnn(
         num_features=params.feature_dim,
-        num_classes=max_phone_id + 1,  # +1 for the blank symbol
+        num_classes=max_token_id + 1,  # +1 for the blank symbol
     )
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -394,26 +295,18 @@
     model.eval()
 
     yes_no = YesNoAsrDataModule(args)
-    # CAUTION: `test_sets` is for displaying only.
-    # If you want to skip test-clean, you have to skip
-    # it inside the for loop. That is, use
-    #
-    #   if test_set == 'test-clean': continue
-    #
-    test_sets = ["test"]
-    for test_set, test_dl in zip(test_sets, [yes_no.test_dataloaders()]):
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            HLG=HLG,
-            lexicon=lexicon,
-            G=G,
-        )
+    test_dl = yes_no.test_dataloaders()
+    results = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        HLG=HLG,
+        word_table=lexicon.word_table,
+    )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+    save_results(
+        exp_dir=params.exp_dir, test_set_name="test_set", results=results
+    )
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index a66b16db2..7cce8a54a 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -496,6 +496,10 @@ def run(rank, world_size, args):
 
     yes_no = YesNoAsrDataModule(args)
     train_dl = yes_no.train_dataloaders()
+
+    # There are only 60 waves: 30 files are used for training
+    # and the remaining 30 files are used for testing.
+    # We use test data as validation.
     valid_dl = yes_no.test_dataloaders()
 
     for epoch in range(params.start_epoch, params.num_epochs):
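Note: the sketch below is not part of the patch. It is a condensed, illustrative view of the decoding path that `egs/yesno/ASR/tdnn/decode.py` converges on after this change: unconditional 1best decoding followed by a `word_table` lookup, with the per-`lm_scale` dict bookkeeping gone. The helper name `lattice_to_word_lists` is made up for illustration, the import locations are assumed to match what the file already pulls in from icefall, and building `lattice` from the network output is omitted.

```python
from typing import List

import k2

# Assumed import locations, matching what tdnn/decode.py already uses.
from icefall.decode import one_best_decoding
from icefall.utils import get_texts


def lattice_to_word_lists(
    lattice: k2.Fsa,
    word_table: k2.SymbolTable,
    use_double_scores: bool = True,
) -> List[List[str]]:
    """Turn a decoding lattice into per-utterance word lists (1best only)."""
    # Pick the single best path through the lattice; no nbest or LM rescoring.
    best_path = one_best_decoding(
        lattice=lattice, use_double_scores=use_double_scores
    )
    # Convert the best path to word IDs, then map each ID to its word string.
    hyps = get_texts(best_path)
    return [[word_table[i] for i in ids] for ids in hyps]
```

Because there is no `key`/`lm_scale` to track any more, `decode_dataset` can return a flat list of `(ref_words, hyp_words)` pairs and the new `save_results` writes exactly one `recogs-*.txt`/`errs-*.txt` pair per test set.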