diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/aishell/conformer_ctc.rst
index 96ede5b68..d225be9c6 100644
--- a/docs/source/recipes/aishell/conformer_ctc.rst
+++ b/docs/source/recipes/aishell/conformer_ctc.rst
@@ -97,13 +97,17 @@ Configurable options
 shows you the training options that can be passed from the commandline.
 The following options are used quite often:
 
+  - ``--exp-dir``
+
+    The experiment folder to save logs and model checkpoints.
+    Default: ``./conformer_ctc/exp``.
   - ``--num-epochs``
 
     It is the number of epochs to train. For instance,
     ``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
     and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
-    in the folder ``./conformer_ctc/exp``.
+    in the folder set with ``--exp-dir``.
 
   - ``--start-epoch``
@@ -174,7 +178,7 @@ Pre-configured options
 ~~~~~~~~~~~~~~~~~~~~~~
 
 There are some training options, e.g., weight decay,
-number of warmup steps, results dir, etc,
+number of warmup steps, etc,
 that are not passed from the commandline.
 They are pre-configured by the function ``get_params()`` in
 `conformer_ctc/train.py `_
@@ -192,8 +196,8 @@ them, please modify ``./conformer_ctc/train.py`` directly.
 
 Training logs
 ~~~~~~~~~~~~~
 
-Training logs and checkpoints are saved in ``conformer_ctc/exp``.
-You will find the following files in that directory:
+Training logs and checkpoints are saved in the folder set by ``--exp-dir``
+(default ``conformer_ctc/exp``). You will find the following files in that directory:
 
   - ``epoch-0.pt``, ``epoch-1.pt``, ...
@@ -223,10 +227,10 @@ You will find the following files in that directory:
        To stop uploading, press Ctrl-C.
 
-       New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/qvNrx6JIQAaN5Ly3uQotrg/
+       New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/WE1DocDqRRCOSAgmGyClhg/
 
-       [2021-09-12T16:41:16] Started scanning logdir.
-       [2021-09-12T16:42:17] Total uploaded: 125346 scalars, 0 tensors, 0 binary objects
+       [2021-11-16T10:51:46] Started scanning logdir.
+       [2021-11-16T10:52:32] Total uploaded: 111606 scalars, 0 tensors, 0 binary objects
 
        Listening for new data in logdir...
 
    Note there is a URL in the above output, click it and you will see
@@ -236,7 +240,7 @@ You will find the following files in that directory:
 
       :width: 600
      :alt: TensorBoard screenshot
      :align: center
-      :target: https://tensorboard.dev/experiment/qvNrx6JIQAaN5Ly3uQotrg/
+      :target: https://tensorboard.dev/experiment/WE1DocDqRRCOSAgmGyClhg/
 
      TensorBoard screenshot.
@@ -307,9 +311,9 @@ The commonly used options are:
     .. code-block::
 
       $ cd egs/aishell/ASR
-      $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
+      $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5
 
-  - ``--lattice-score-scale``
+  - ``--nbest-scale``
 
    It is used to scale down lattice scores so that there are more unique
    paths for rescoring.
@@ -403,7 +407,29 @@ After downloading, you will have the following files:
   - ``exp/pretrained.pt``
 
      It contains pre-trained model parameters, obtained by averaging
-      checkpoints from ``epoch-18.pt`` to ``epoch-40.pt``.
+      checkpoints from ``epoch-25.pt`` to ``epoch-84.pt``.
      Note: We have removed optimizer ``state_dict`` to reduce file size.
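+
+      If you want to reproduce such an averaged checkpoint, a minimal
+      sketch is given below. It assumes, as the note above implies, that
+      each ``epoch-N.pt`` stores its parameters under the ``model`` key
+      (icefall ships a similar helper, ``average_checkpoints``):
+
+      .. code-block:: python
+
+        import torch
+
+        filenames = [f"epoch-{i}.pt" for i in range(25, 85)]
+        avg = torch.load(filenames[0], map_location="cpu")["model"]
+        for f in filenames[1:]:
+            state = torch.load(f, map_location="cpu")["model"]
+            for k in avg:
+                avg[k] += state[k]
+        for k in avg:  # average; integer buffers use floor division
+            if avg[k].is_floating_point():
+                avg[k] /= len(filenames)
+            else:
+                avg[k] //= len(filenames)
+        torch.save({"model": avg}, "pretrained.pt")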
 
   - ``test_waves/*.wav``
 
@@ -483,7 +487,7 @@ The command to run HLG decoding is:
     --method 1best \
     ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \
     ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \
-    ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav 
+    ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav
 
 The output is given below:
@@ -527,7 +531,7 @@ The command to run HLG decoding + attention decoder rescoring is:
     --method attention-decoder \
     ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \
     ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \
-    ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav 
+    ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav
 
 The output is below:
diff --git a/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg b/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
index 47f7d18a7..7ea2e8369 100644
Binary files a/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg and b/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg differ
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 5cbe5d213..7b9650ae1 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -1,16 +1,16 @@
 ## Results
 
 ### Aishell training results (Conformer-CTC)
-#### 2021-09-13
+#### 2021-11-16
 (Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/30
 
 Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc
 
-The best decoding results (CER) are listed below, we got this results by averaging models from epoch 23 to 40, and using `attention-decoder` decoder with num_paths equals to 100.
+The best decoding results (CER) are listed below. We got these results by averaging the models from epochs 25 to 84 and using the `attention-decoder` method with `num_paths` equal to 100.
 
 ||test|
 |--|--|
-|CER| 4.74% |
+|CER| 4.26% |
 
 To get more unique paths, we scaled the lattice.scores with 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details), we searched the lm_score_scale and attention_score_scale for best results, the scales that produced the CER above are also listed below.
 
@@ -27,17 +27,18 @@ cd icefall
 cd egs/aishell/ASR
 ./prepare.sh
 
-export CUDA_VISIBLE_DEVICES="0,1"
-python conformer_ctc/train.py --bucketing-sampler False \
-                              --concatenate-cuts False \
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+python conformer_ctc/train.py --bucketing-sampler True \
                               --max-duration 200 \
-                              --world-size 2
+                              --start-epoch 0 \
+                              --num-epochs 90 \
+                              --world-size 4
 
-python conformer_ctc/decode.py --lattice-score-scale 0.5 \
-                               --epoch 40 \
-                               --avg 18 \
+python conformer_ctc/decode.py --nbest-scale 0.5 \
+                               --epoch 84 \
+                               --avg 25 \
                                --method attention-decoder \
-                               --max-duration 50 \
+                               --max-duration 20 \
                                --num-paths 100
 ```
@@ -53,4 +54,3 @@ The best decoding results (CER) are listed below, we got this results by averagi
 ||test|
 |--|--|
 |CER| 10.16% |
-
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index ee2f31483..58ce39cca 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -77,6 +77,8 @@ def get_parser():
         default="attention-decoder",
         help="""Decoding method.
         Supported values are:
+        - (0) ctc-decoding. Use CTC decoding. It maps the token IDs to
+          tokens using the token symbol table directly.
         - (1) 1best. Extract the best path from the decoding lattice as the
           decoding result.
         - (2) nbest. Extract n paths from the decoding lattice; the path
diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py
index 846681f00..4459c10c6 100755
--- a/egs/aishell/ASR/conformer_ctc/pretrained.py
+++ b/egs/aishell/ASR/conformer_ctc/pretrained.py
@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_attention_decoder,
 )
-from icefall.utils import AttributeDict, get_texts
+from icefall.utils import AttributeDict, get_env_info, get_texts
 
 
 def get_parser():
@@ -52,14 +52,21 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--words-file",
+        "--tokens-file",
         type=str,
-        required=True,
-        help="Path to words.txt",
+        help="Path to tokens.txt. Used only when method is ctc-decoding",
     )
 
     parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
+        "--words-file",
+        type=str,
+        help="Path to words.txt. Used when method is NOT ctc-decoding",
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="Path to HLG.pt. Used when method is NOT ctc-decoding",
     )
 
     parser.add_argument(
@@ -68,6 +75,8 @@ def get_parser():
         default="1best",
         help="""Decoding method.
         Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It maps the token IDs to tokens
+            using the token symbol table directly.
         (1) 1best - Use the best path as decoding output. Only
             the transformer encoder output is used for decoding.
             We call it HLG decoding.
@@ -111,7 +120,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""
@@ -125,7 +134,7 @@ def get_parser():
 
     parser.add_argument(
         "--sos-id",
-        type=float,
+        type=int,
         default=1,
         help="""
         Used only when method is attention-decoder.
@@ -135,7 +144,7 @@ def get_parser():
 
     parser.add_argument(
         "--eos-id",
-        type=float,
+        type=int,
         default=1,
         help="""
         Used only when method is attention-decoder.
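[Review note] For readers skimming these hunks: the essence of the new
ctc-decoding path (wired up in main() further below) is to intersect the
network output with a CTC topology and map the best path's token IDs
straight to symbols, with no lexicon or LM involved. A minimal sketch;
the tokens.txt path and the token-ID output are illustrative placeholders:

```python
import k2

# CTC topology accepting token IDs 0..max_token_id; intersecting it with
# the network output (via get_lattice) yields the decoding lattice.
max_token_id = 4336 - 1  # num_classes - 1
H = k2.ctc_topo(max_token=max_token_id, modified=False)

# Map the best path's token IDs back to symbols.
token_sym_table = k2.SymbolTable.from_file("data/lang_char/tokens.txt")
token_ids = [[59, 328, 7]]  # stand-in for get_texts(best_path)
hyps = [[token_sym_table[i] for i in ids] for ids in token_ids]
```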
@@ -143,6 +152,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=4336,
+        help="The vocab size.",
+    )
+
     parser.add_argument(
         "sound_files",
         type=str,
@@ -160,7 +176,6 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "sample_rate": 16000,
-            "num_classes": 4336,
             # parameters for conformer
             "subsampling_factor": 4,
             "feature_dim": 80,
@@ -175,6 +190,7 @@ def get_params() -> AttributeDict:
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
+            "env_info": get_env_info(),
         }
     )
     return params
@@ -212,6 +228,11 @@ def main():
     params.update(vars(args))
     logging.info(f"{params}")
 
+    if args.method != "attention-decoder":
+        # to save memory as the attention decoder
+        # will not be used
+        params.num_decoder_layers = 0
+
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)
@@ -231,17 +252,10 @@ def main():
     )
 
     checkpoint = torch.load(args.checkpoint, map_location="cpu")
-    model.load_state_dict(checkpoint["model"])
+    model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
 
-    logging.info(f"Loading HLG from {params.HLG}")
-    HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
-    HLG = HLG.to(device)
-    if not hasattr(HLG, "lm_scores"):
-        # For whole-lattice-rescoring and attention-decoder
-        HLG.lm_scores = HLG.scores.clone()
-
     logging.info("Constructing Fbank computer")
     opts = kaldifeat.FbankOptions()
     opts.device = device
@@ -275,41 +289,79 @@ def main():
         dtype=torch.int32,
     )
 
-    lattice = get_lattice(
-        nnet_output=nnet_output,
-        HLG=HLG,
-        supervision_segments=supervision_segments,
-        search_beam=params.search_beam,
-        output_beam=params.output_beam,
-        min_active_states=params.min_active_states,
-        max_active_states=params.max_active_states,
-        subsampling_factor=params.subsampling_factor,
-    )
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        token_sym_table = k2.SymbolTable.from_file(params.tokens_file)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
 
-    if params.method == "1best":
-        logging.info("Use HLG decoding")
         best_path = one_best_decoding(
             lattice=lattice, use_double_scores=params.use_double_scores
         )
-    elif params.method == "attention-decoder":
-        logging.info("Use HLG + attention decoder rescoring")
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=params.sos_id,
-            eos_id=params.eos_id,
-            scale=params.lattice_score_scale,
-            ngram_lm_scale=params.ngram_lm_scale,
-            attention_scale=params.attention_decoder_scale,
-        )
-        best_path = next(iter(best_path_dict.values()))
+        token_ids = get_texts(best_path)
+        hyps = [[token_sym_table[i] for i in ids] for ids in token_ids]
+    elif params.method in ["1best", "attention-decoder"]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
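+            # (rescoring later uses lm_scores to separate LM and acoustic scores)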
+            HLG.lm_scores = HLG.scores.clone()
 
-    hyps = get_texts(best_path)
-    word_sym_table = k2.SymbolTable.from_file(params.words_file)
-    hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            HLG=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        elif params.method == "attention-decoder":
+            logging.info("Use HLG + attention decoder rescoring")
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=params.sos_id,
+                eos_id=params.eos_id,
+                scale=params.nbest_scale,
+                ngram_lm_scale=params.ngram_lm_scale,
+                attention_scale=params.attention_decoder_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
 
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index 77293f772..b3b9e7681 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -23,6 +23,7 @@ It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/fbank.
 """
 
+import argparse
 import logging
 import os
 from pathlib import Path
@@ -43,7 +44,7 @@ torch.set_num_interop_threads(1)
 
 def compute_fbank_aishell(num_mel_bins: int = 80):
     src_dir = Path("data/manifests")
-    output_dir = Path("data/fbank40")
+    output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
 
     dataset_parts = (
@@ -106,4 +107,3 @@ if __name__ == "__main__":
 
     args = get_args()
     compute_fbank_aishell(num_mel_bins=args.num_mel_bins)
-
diff --git a/egs/aishell/ASR/local/compute_fbank_musan.py b/egs/aishell/ASR/local/compute_fbank_musan.py
index 0b97fb8c5..e79bdafb1 100755
--- a/egs/aishell/ASR/local/compute_fbank_musan.py
+++ b/egs/aishell/ASR/local/compute_fbank_musan.py
@@ -23,6 +23,7 @@ It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/fbank.
""" +import argparse import logging import os from pathlib import Path @@ -43,7 +44,7 @@ torch.set_num_interop_threads(1) def compute_fbank_musan(num_mel_bins: int = 80): src_dir = Path("data/manifests") - output_dir = Path("data/fbank40") + output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) dataset_parts = ( @@ -86,6 +87,7 @@ def compute_fbank_musan(num_mel_bins: int = 80): ) musan_cuts.to_json(musan_cuts_path) + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -106,4 +108,3 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() compute_fbank_musan(num_mel_bins=args.num_mel_bins) - diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index f70e89c65..584e50cb0 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -69,7 +69,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # |-- lexicon.txt # `-- speaker.info - if [ ! -d $dl_dir/aishell/wav ]; then + if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then lhotse download aishell $dl_dir fi @@ -133,7 +133,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt | cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > data/lang_char/text - + if [ ! -f data/lang_char/L_disambig.pt ]; then ./local/prepare_char.py fi @@ -160,4 +160,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then ./local/compile_hlg.py --lang-dir data/lang_phone ./local/compile_hlg.py --lang-dir data/lang_char fi -