From f65854cca5aeb6c1c6e473f620bd9fc914a4cad4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 27 Jul 2021 17:38:47 +0800 Subject: [PATCH] Add BPE decoding results. --- egs/librispeech/ASR/README.md | 117 ++++++ egs/librispeech/ASR/conformer_ctc/decode.py | 388 ++++++++++++++++++-- egs/librispeech/ASR/local/compile_hlg.py | 4 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 2 + icefall/decode.py | 7 +- 5 files changed, 479 insertions(+), 39 deletions(-) diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 71c333aaf..45c9ef4de 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -2,3 +2,120 @@ Run `./prepare.sh` to prepare the data. Run `./xxx_train.py` (to be added) to train a model. + +## Conformer-CTC +Results of the pre-trained model from +`` +are given below + +### HLG - no LM rescoring + +(output beam size is 8) + +#### 1-best decoding + +``` +[test-clean-no_rescore] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ] +[test-other-no_rescore] %WER 7.03% [3682 / 52343, 220 ins, 1024 del, 2438 sub ] +``` + +#### n-best decoding + +For n=100, + +``` +[test-clean-no_rescore-100] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ] +[test-other-no_rescore-100] %WER 7.14% [3737 / 52343, 275 ins, 1020 del, 2442 sub ] +``` + +For n=200, + +``` +[test-clean-no_rescore-200] %WER 3.16% [1660 / 52576, 125 ins, 378 del, 1157 sub ] +[test-other-no_rescore-200] %WER 7.04% [3684 / 52343, 228 ins, 1012 del, 2444 sub ] +``` + +### HLG - with LM rescoring + +#### Whole lattice rescoring + +``` +[test-clean-lm_scale_0.8] %WER 2.77% [1456 / 52576, 150 ins, 210 del, 1096 sub ] +[test-other-lm_scale_0.8] %WER 6.23% [3262 / 52343, 246 ins, 635 del, 2381 sub ] +``` + +WERs of different LM scales are: + +``` +For test-clean, WER of different settings are: +lm_scale_0.8 2.77 best for test-clean +lm_scale_0.9 2.87 +lm_scale_1.0 3.06 +lm_scale_1.1 3.34 +lm_scale_1.2 3.71 +lm_scale_1.3 4.18 +lm_scale_1.4 4.8 +lm_scale_1.5 5.48 +lm_scale_1.6 6.08 +lm_scale_1.7 6.79 +lm_scale_1.8 7.49 +lm_scale_1.9 8.14 +lm_scale_2.0 8.82 + +For test-other, WER of different settings are: +lm_scale_0.8 6.23 best for test-other +lm_scale_0.9 6.37 +lm_scale_1.0 6.62 +lm_scale_1.1 6.99 +lm_scale_1.2 7.46 +lm_scale_1.3 8.13 +lm_scale_1.4 8.84 +lm_scale_1.5 9.61 +lm_scale_1.6 10.32 +lm_scale_1.7 11.17 +lm_scale_1.8 12.12 +lm_scale_1.9 12.93 +lm_scale_2.0 13.77 +``` + +#### n-best LM rescoring + +n = 100 + +``` +[test-clean-lm_scale_0.8] %WER 2.79% [1469 / 52576, 149 ins, 212 del, 1108 sub ] +[test-other-lm_scale_0.8] %WER 6.36% [3329 / 52343, 259 ins, 666 del, 2404 sub ] +``` + +WERs of different LM scales are: + +``` +For test-clean, WER of different settings are: +lm_scale_0.8 2.79 best for test-clean +lm_scale_0.9 2.89 +lm_scale_1.0 3.03 +lm_scale_1.1 3.28 +lm_scale_1.2 3.52 +lm_scale_1.3 3.78 +lm_scale_1.4 4.04 +lm_scale_1.5 4.24 +lm_scale_1.6 4.45 +lm_scale_1.7 4.58 +lm_scale_1.8 4.7 +lm_scale_1.9 4.8 +lm_scale_2.0 4.92 +For test-other, WER of different settings are: +lm_scale_0.8 6.36 best for test-other +lm_scale_0.9 6.45 +lm_scale_1.0 6.64 +lm_scale_1.1 6.92 +lm_scale_1.2 7.25 +lm_scale_1.3 7.59 +lm_scale_1.4 7.88 +lm_scale_1.5 8.13 +lm_scale_1.6 8.36 +lm_scale_1.7 8.54 +lm_scale_1.8 8.71 +lm_scale_1.9 8.88 +lm_scale_2.0 9.02 +``` diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 74fd3060b..a9d2b465c 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py 
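A note on reading the `%WER` lines in the README results above: the bracketed numbers follow the common Kaldi-style convention of total errors over the number of reference words, followed by the insertion, deletion and substitution counts. A minimal sketch that re-derives the first `test-clean` figure purely from the numbers quoted above:

```python
# Re-derive "%WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ]".
ins, dels, subs = 127, 377, 1152
ref_words = 52576

errors = ins + dels + subs            # 1656 total errors
wer = 100.0 * errors / ref_words      # ~3.15
print(f"%WER {wer:.2f}% [{errors} / {ref_words}]")
```

The same arithmetic applies to every result block above; the per-scale summaries such as `lm_scale_0.8 2.77 best for test-clean` are produced by `save_results` in `conformer_ctc/decode.py` further down in this patch, which sorts the WERs and flags the lowest one.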
@@ -6,13 +6,25 @@ import argparse import logging +from collections import defaultdict from pathlib import Path +from typing import Dict, List, Optional, Tuple +import k2 import torch +import torch.nn as nn from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.dataset.librispeech import LibriSpeechAsrDataModule +from icefall.decode import ( + get_lattice, + nbest_decoding, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, get_texts, @@ -22,40 +34,6 @@ from icefall.utils import ( ) -def get_params() -> AttributeDict: - params = AttributeDict( - { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang/bpe"), - "lm_dir": Path("data/lm"), - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_classes": 5000, - "subsampling_factor": 4, - "num_decoder_layers": 6, - "vgg_frontend": False, - "is_espnet_structure": True, - "mmi_loss": False, - "use_feat_batchnorm": True, - "search_beam": 20, - "output_beam": 5, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - # Possible values for method: - # - 1best - # - nbest - # - nbest-rescoring - # - whole-lattice-rescoring - "method": "whole-lattice-rescoring", - # num_paths is used when method is "nbest" and "nbest-rescoring" - "num_paths": 30, - } - ) - return params - - def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -79,6 +57,270 @@ def get_parser(): return parser +def get_params() -> AttributeDict: + params = AttributeDict( + { + "exp_dir": Path("conformer_ctc/exp"), + "lang_dir": Path("data/lang/bpe"), + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "subsampling_factor": 4, + "num_decoder_layers": 6, + "vgg_frontend": False, + "is_espnet_structure": True, + "mmi_loss": False, + "use_feat_batchnorm": True, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + # Possible values for method: + # - 1best + # - nbest + # - nbest-rescoring + # - whole-lattice-rescoring + "method": "nbest-rescoring", + # num_paths is used when method is "nbest" and "nbest-rescoring" + "num_paths": 100, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: k2.Fsa, + batch: dict, + lexicon: Lexicon, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[int]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. 
+ batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + lexicon: + It contains word symbol table. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = HLG.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is [N, T, C] + + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + + supervisions = batch["supervisions"] + + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is [N, C, T] + + nnet_output = nnet_output.permute(0, 2, 1) + # now nnet_output is [N, T, C] + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + lattice = get_lattice( + nnet_output=nnet_output, + HLG=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + ) + key = f"no_rescore-{params.num_paths}" + + hyps = get_texts(best_path) + hyps = [[lexicon.words[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] + + lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + ) + else: + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list + ) + + ans = dict() + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[lexicon.words[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: k2.Fsa, + lexicon: Lexicon, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[int], List[int]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. + lexicon: + It contains word symbol table. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. 
Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + results = [] + + num_cuts = 0 + tot_num_cuts = len(dl.dataset.cuts) + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + batch=batch, + lexicon=lexicon, + G=G, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + logging.info( + f"batch {batch_idx}, cuts processed until now is " + f"{num_cuts}/{tot_num_cuts} " + f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + @torch.no_grad() def main(): parser = get_parser() @@ -92,15 +334,64 @@ def main(): logging.info("Decoding started") logging.info(params) + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) + logging.info(f"device: {device}") + + HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt")) + HLG = HLG.to(device) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]: + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.words["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. 
+ G.labels[G.labels >= first_word_disambig_id] = 0 + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt") + G = k2.Fsa.from_dict(d).to(device) + + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + model = Conformer( num_features=params.feature_dim, nhead=params.nhead, d_model=params.attention_dim, - num_classes=params.num_classes, + num_classes=num_classes, subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=params.vgg_frontend, @@ -122,7 +413,32 @@ def main(): model.to(device) model.eval() - token_ids_with_blank = list(range(params.num_classes)) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + # CAUTION: `test_sets` is for displaying only. + # If you want to skip test-clean, you have to skip + # it inside the for loop. That is, use + # + # if test_set == 'test-clean': continue + # + test_sets = ["test-clean", "test-other"] + for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + lexicon=lexicon, + G=G, + ) + + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) + + logging.info("Done!") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index c30bf9fba..dc9105418 100644 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -39,7 +39,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: if Path("data/lm/G_3_gram.pt").is_file(): print("Loading pre-compiled G_3_gram") d = torch.load("data/lm/G_3_gram.pt") - G = k2.Fsa.from_dict(d).to(device) + G = k2.Fsa.from_dict(d) else: print("Loading G_3_gram.fst.txt") with open("data/lm/G_3_gram.fst.txt") as f: @@ -114,7 +114,7 @@ def bpe_based_HLG(): print("Compiling BPE based HLG") HLG = compile_HLG("data/lang/bpe") - print("Saving HLG.pt to data/lm") + print("Saving HLG_bpe.pt to data/lm") torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 2a6dc671e..2a29190c9 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -326,6 +326,8 @@ def main(): if torch.cuda.is_available(): device = torch.device("cuda", 0) + logging.info(f"device: {device}") + HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt")) HLG = HLG.to(device) assert HLG.requires_grad is False diff --git a/icefall/decode.py b/icefall/decode.py index bb8d0c10e..0ab712b3b 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -54,6 +54,7 @@ def get_lattice( output_beam: float, min_active_states: int, max_active_states: int, + subsampling_factor: int = 1, ): """Get the decoding lattice from a decoding graph and neural network output. @@ -87,10 +88,14 @@ def get_lattice( frame for any given intersection/composition task. 
This is advisory, in that it will try not to exceed that but may not always succeed. You can use a very large number if no constraint is needed. + subsampling_factor: + The subsampling factor of the model. Returns: A lattice containing the decoding result. """ - dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1 + ) lattice = k2.intersect_dense_pruned( HLG,
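A note on the `allow_truncate=subsampling_factor - 1` change above: `decode.py` builds `supervision_segments` with `num_frames // subsampling_factor`, while the model's convolutional front end can emit a few frames fewer than that, so a segment may claim more frames than `nnet_output` actually contains; a non-zero `allow_truncate` lets `k2.DenseFsaVec` truncate such segments by at most that many frames instead of treating them as invalid. The sketch below illustrates the size of the mismatch; the `((T - 1) // 2 - 1) // 2` output-length formula is an assumption about a Conv2dSubsampling-style front end, not something stated in this patch:

```python
# Rough illustration (assumed front-end formula) of why the segment length computed
# from the input frame count can exceed the number of frames the network outputs.
subsampling_factor = 4

for T in (996, 1000, 1003):                # example input lengths
    seg_len = T // subsampling_factor      # length recorded in supervision_segments
    out_len = ((T - 1) // 2 - 1) // 2      # assumed Conv2dSubsampling output length
    print(T, seg_len, out_len, seg_len - out_len)

# For this formula the gap never reaches subsampling_factor, so allowing DenseFsaVec
# to truncate up to subsampling_factor - 1 frames keeps every segment valid.
```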