From 9d0cc9d8297a721ee7763e87bbd28cc24a289b18 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 20 Aug 2021 11:53:37 +0800 Subject: [PATCH] Support computing nbest oracle WER. (#10) * Support computing nbest oracle WER. * Add scale to all nbest based decoding/rescoring methods. * Add script to run pretrained models. * Use torchaudio to extract features. * Support decoding multiple files at the same time. Also, use kaldifeat for feature extraction. * Support decoding with LM rescoring and attention-decoder rescoring. * Minor fixes. * Replace scale with lattice-score-scale. * Add usage example with a provided pretrained model. --- egs/librispeech/ASR/conformer_ctc/README.md | 339 +++++++++++++++++ egs/librispeech/ASR/conformer_ctc/decode.py | 37 +- .../ASR/conformer_ctc/pretrained.py | 350 ++++++++++++++++++ egs/librispeech/ASR/conformer_ctc/train.py | 2 +- icefall/decode.py | 188 +++++++++- 5 files changed, 893 insertions(+), 23 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_ctc/README.md create mode 100755 egs/librispeech/ASR/conformer_ctc/pretrained.py diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md new file mode 100644 index 000000000..a02ec35af --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/README.md @@ -0,0 +1,339 @@ + +# How to use a pre-trained model to transcribe a sound file or multiple sound files + +You need to prepare 4 files: + + - a model checkpoint file, e.g., epoch-20.pt + - HLG.pt, the decoding graph + - words.txt, the word symbol table + - a sound file, whose sampling rate has to be 16 kHz. + Supported formats are those supported by `torchaudio.load()`, + e.g., wav and flac. + +Also, you need to install `kaldifeat`. Please refer to + for installation. + +```bash +./conformer_ctc/pretrained.py --help +``` + +displays the help information. + +## HLG decoding + +Once you have the above files ready and have `kaldifeat` installed, +you can run: + +```bash +./conformer_ctc/pretrained.py \ + --checkpoint /path/to/your/checkpoint.pt \ + --words-file /path/to/words.txt \ + --HLG /path/to/HLG.pt \ + /path/to/your/sound.wav +``` + +and you will see the transcribed result. + +If you want to transcribe multiple files at the same time, you can use: + +```bash +./conformer_ctc/pretrained.py \ + --checkpoint /path/to/your/checkpoint.pt \ + --words-file /path/to/words.txt \ + --HLG /path/to/HLG.pt \ + /path/to/your/sound1.wav \ + /path/to/your/sound2.wav \ + /path/to/your/sound3.wav +``` + +**Note**: This is the fastest decoding method. + +## HLG decoding + LM rescoring + +`./conformer_ctc/pretrained.py` also supports `whole lattice LM rescoring` +and `attention decoder rescoring`. + +To use whole lattice LM rescoring, you also need the following files: + + - G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh` + +The command to run decoding with LM rescoring is: + +```bash +./conformer_ctc/pretrained.py \ + --checkpoint /path/to/your/checkpoint.pt \ + --words-file /path/to/words.txt \ + --HLG /path/to/HLG.pt \ + --method whole-lattice-rescoring \ + --G data/lm/G_4_gram.pt \ + --ngram-lm-scale 0.8 \ + /path/to/your/sound1.wav \ + /path/to/your/sound2.wav \ + /path/to/your/sound3.wav +``` + +## HLG Decoding + LM rescoring + attention decoder rescoring + +To use attention decoder for rescoring, you need the following extra information: + + - sos token ID + - eos token ID + +The command to run decoding with attention decoder rescoring is: + +```bash +./conformer_ctc/pretrained.py \ + --checkpoint /path/to/your/checkpoint.pt \ + --words-file /path/to/words.txt \ + --HLG /path/to/HLG.pt \ + --method attention-decoder \ + --G data/lm/G_4_gram.pt \ + --ngram-lm-scale 1.3 \ + --attention-decoder-scale 1.2 \ + --lattice-score-scale 0.5 \ + --num-paths 100 \ + --sos-id 1 \ + --eos-id 1 \ + /path/to/your/sound1.wav \ + /path/to/your/sound2.wav \ + /path/to/your/sound3.wav +``` + +# Decoding with a pretrained model in action + +We have uploaded a pretrained model to + +The following shows the steps about the usage of the provided pretrained model. + +### (1) Download the pretrained model + +```bash +cd /path/to/icefall/egs/librispeech/ASR +mkdir tmp +cd tmp +git clone https://huggingface.co/pkufool/conformer_ctc + +``` + +You will find the following files: + +``` +tmp +`-- conformer_ctc + |-- README.md + |-- data + | |-- lang_bpe + | | |-- HLG.pt + | | |-- bpe.model + | | |-- tokens.txt + | | `-- words.txt + | `-- lm + | `-- G_4_gram.pt + |-- exp + | `-- pretraind.pt + `-- test_wavs + |-- 1089-134686-0001.flac + |-- 1221-135766-0001.flac + |-- 1221-135766-0002.flac + `-- trans.txt + +6 directories, 11 files +``` + +**File descriptions**: + + - `data/lang_bpe/HLG.pt` + + It is the decoding graph. + + - `data/lang_bpe/bpe.model` + + It is a sentencepiece model. You can use it to reproduce our results. + + - `data/lang_bpe/tokens.txt` + + It contains tokens and their IDs, generated from `bpe.model`. + Provided only for convienice so that you can look up the SOS/EOS ID easily. + + - `data/lang_bpe/words.txt` + + It contains words and their IDs. + + - `data/lm/G_4_gram.pt` + + It is a 4-gram LM, useful for LM rescoring. + + - `exp/pretrained.pt` + + It contains pretrained model parameters, obtained by averaging + checkpoints from `epoch-15.pt` to `epoch-34.pt`. + Note: We have removed optimizer `state_dict` to reduce file size. + + - `test_waves/*.flac` + + It contains some test sound files from LibriSpeech `test-clean` dataset. + + - `test_waves/trans.txt` + + It contains the reference transcripts for the sound files in `test_waves/`. + +The information of the test sound files is listed below: + +``` +$ soxi tmp/conformer_ctc/test_wavs/*.flac + +Input File : 'tmp/conformer_ctc/test_wavs/1089-134686-0001.flac' +Channels : 1 +Sample Rate : 16000 +Precision : 16-bit +Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors +File Size : 116k +Bit Rate : 140k +Sample Encoding: 16-bit FLAC + +Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0001.flac' +Channels : 1 +Sample Rate : 16000 +Precision : 16-bit +Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors +File Size : 343k +Bit Rate : 164k +Sample Encoding: 16-bit FLAC + +Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0002.flac' +Channels : 1 +Sample Rate : 16000 +Precision : 16-bit +Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors +File Size : 105k +Bit Rate : 174k +Sample Encoding: 16-bit FLAC + +Total Duration of 3 files: 00:00:28.16 +``` + +### (2) Use HLG decoding + +```bash +cd /path/to/icefall/egs/librispeech/ASR + +./conformer_ctc/pretrained.py \ + --checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \ + --words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \ + --HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \ + ./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac +``` + +The output is given below: + +``` +2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0 +2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model +2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt +2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer +2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac'] +2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started +2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding +2021-08-20 11:03:19,149 INFO [pretrained.py:339] +./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac: +AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + +./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac: +GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED +BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + +./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac: +YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + +2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done +``` + +### (3) Use HLG decoding + LM rescoring + +```bash +./conformer_ctc/pretrained.py \ + --checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \ + --words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \ + --HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \ + --method whole-lattice-rescoring \ + --G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \ + --ngram-lm-scale 0.8 \ + ./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac +``` + +The output is: + +``` +2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0 +2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model +2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt +2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt +2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer +2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac'] +2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started +2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring +2021-08-20 11:13:11,736 INFO [pretrained.py:339] +./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac: +AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + +./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac: +GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED +BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + +./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac: +YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + +2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done +``` + +### (4) Use HLG decoding + LM rescoring + attention decoder rescoring + +```bash +./conformer_ctc/pretrained.py \ + --checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \ + --words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \ + --HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \ + --method attention-decoder \ + --G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \ + --ngram-lm-scale 1.3 \ + --attention-decoder-scale 1.2 \ + --lattice-score-scale 0.5 \ + --num-paths 100 \ + --sos-id 1 \ + --eos-id 1 \ + ./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac +``` + +The output is: + +``` +2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0 +2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model +2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt +2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt +2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer +2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac'] +2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started +2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring +2021-08-20 11:20:05,805 INFO [pretrained.py:339] +./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac: +AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + +./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac: +GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED +BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + +./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac: +YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + +2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done +``` diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 1080ac031..c17a8b284 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -21,6 +21,7 @@ from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.decode import ( get_lattice, nbest_decoding, + nbest_oracle, one_best_decoding, rescore_with_attention_decoder, rescore_with_n_best_list, @@ -56,6 +57,18 @@ def get_parser(): "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", ) + + parser.add_argument( + "--lattice-score-scale", + type=float, + default=1.0, + help="The scale to be applied to `lattice.scores`." + "It's needed if you use any kinds of n-best based rescoring. " + "Currently, it is used when the decoding method is: nbest, " + "nbest-rescoring, attention-decoder, and nbest-oracle. " + "A smaller value results in more unique paths.", + ) + return parser @@ -85,10 +98,14 @@ def get_params() -> AttributeDict: # - nbest-rescoring # - whole-lattice-rescoring # - attention-decoder + # - nbest-oracle + # "method": "nbest", + # "method": "nbest-rescoring", # "method": "whole-lattice-rescoring", "method": "attention-decoder", + # "method": "nbest-oracle", # num_paths is used when method is "nbest", "nbest-rescoring", - # and attention-decoder + # attention-decoder, and nbest-oracle "num_paths": 100, } ) @@ -179,6 +196,19 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is slightly worse than that of rescored lattices. + return nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + lexicon=lexicon, + scale=params.lattice_score_scale, + ) + if params.method in ["1best", "nbest"]: if params.method == "1best": best_path = one_best_decoding( @@ -190,8 +220,9 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, + scale=params.lattice_score_scale, ) - key = f"no_rescore-{params.num_paths}" + key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" hyps = get_texts(best_path) hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] @@ -212,6 +243,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, + scale=params.lattice_score_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( @@ -231,6 +263,7 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, + scale=params.lattice_score_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py new file mode 100755 index 000000000..fbdeb39b5 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from conformer import Conformer +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + (3) attention-decoder - Extract n paths from he rescored + lattice and use the transformer attention decoder for + rescoring. + We call it HLG decoding + n-gram LM rescoring + attention + decoder rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or attention-decoder. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and attention-decoder. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--attention-decoder-scale", + type=float, + default=1.2, + help=""" + Used only when method is attention-decoder. + It specifies the scale for attention decoder scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--lattice-score-scale", + type=float, + default=0.5, + help=""" + Used only when method is attention-decoder. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sos-id", + type=float, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the SOS token. + """, + ) + + parser.add_argument( + "--eos-id", + type=float, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the EOS token. + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "nhead": 8, + "num_classes": 5000, + "sample_rate": 16000, + "attention_dim": 512, + "subsampling_factor": 4, + "num_decoder_layers": 6, + "vgg_frontend": False, + "is_espnet_structure": True, + "mmi_loss": False, + "use_feat_batchnorm": True, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + is_espnet_structure=params.is_espnet_structure, + mmi_loss=params.mmi_loss, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + G.lm_scores = G.scores.clone() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info(f"Decoding started") + features = fbank(waves) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + # Note: We don't use key padding mask for attention during decoding + with torch.no_grad(): + nnet_output, memory, memory_key_padding_mask = model(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + HLG=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + scale=params.lattice_score_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info(f"Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 36464439b..d3ea8efb0 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -128,7 +128,7 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp_new"), + "exp_dir": Path("conformer_ctc/exp"), "lang_dir": Path("data/lang_bpe"), "feature_dim": 80, "weight_decay": 1e-6, diff --git a/icefall/decode.py b/icefall/decode.py index 2d3e1ed56..49d642f1c 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -2,9 +2,42 @@ import logging from typing import Dict, List, Optional, Tuple, Union import k2 +import kaldialign import torch import torch.nn as nn +from icefall.lexicon import Lexicon + + +def _get_random_paths( + lattice: k2.Fsa, + num_paths: int, + use_double_scores: bool = True, + scale: float = 1.0, +): + """ + Args: + lattice: + The decoding lattice, returned by :func:`get_lattice`. + num_paths: + It specifies the size `n` in n-best. Note: Paths are selected randomly + and those containing identical word sequences are remove dand only one + of them is kept. + use_double_scores: + True to use double precision floating point in the computation. + False to use single precision. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + Returns: + Return a k2.RaggedInt with 3 axes [seq][path][arc_pos] + """ + saved_scores = lattice.scores.clone() + lattice.scores *= scale + path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + lattice.scores = saved_scores + return path + def _intersect_device( a_fsas: k2.Fsa, @@ -129,7 +162,10 @@ def one_best_decoding( def nbest_decoding( - lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True + lattice: k2.Fsa, + num_paths: int, + use_double_scores: bool = True, + scale: float = 1.0, ) -> k2.Fsa: """It implements something like CTC prefix beam search using n-best lists. @@ -152,12 +188,18 @@ def nbest_decoding( use_double_scores: True to use double precision floating point in the computation. False to use single precision. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. Returns: An FsaVec containing linear FSAs. """ - # First, extract `num_paths` paths for each sequence. - # path is a k2.RaggedInt with axes [seq][path][arc_pos] - path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + path = _get_random_paths( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + scale=scale, + ) # word_seq is a k2.RaggedInt sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. @@ -320,7 +362,11 @@ def compute_am_and_lm_scores( def rescore_with_n_best_list( - lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float] + lattice: k2.Fsa, + G: k2.Fsa, + num_paths: int, + lm_scale_list: List[float], + scale: float = 1.0, ) -> Dict[str, k2.Fsa]: """Decode using n-best list with LM rescoring. @@ -342,6 +388,9 @@ def rescore_with_n_best_list( It is the size `n` in `n-best` list. lm_scale_list: A list containing lm_scale values. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. Returns: A dict of FsaVec, whose key is an lm_scale and the value is the best decoding path for each sequence in the lattice. @@ -356,9 +405,12 @@ def rescore_with_n_best_list( assert G.device == device assert hasattr(G, "aux_labels") is False - # First, extract `num_paths` paths for each sequence. - # path is a k2.RaggedInt with axes [seq][path][arc_pos] - path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + path = _get_random_paths( + lattice=lattice, + num_paths=num_paths, + use_double_scores=True, + scale=scale, + ) # word_seq is a k2.RaggedInt sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. @@ -376,7 +428,7 @@ def rescore_with_n_best_list( # # num_repeats is also a k2.RaggedInt with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.num_elements() + # num_repeats.num_elements() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index @@ -494,6 +546,8 @@ def rescore_with_whole_lattice( del lattice.lm_scores assert hasattr(lattice, "lm_scores") is False + assert hasattr(G_with_epsilon_loops, "lm_scores") + # Now, lattice.scores contains only am_scores # inv_lattice has word IDs as labels. @@ -549,14 +603,88 @@ def rescore_with_whole_lattice( return ans +def nbest_oracle( + lattice: k2.Fsa, + num_paths: int, + ref_texts: List[str], + lexicon: Lexicon, + scale: float = 1.0, +) -> Dict[str, List[List[int]]]: + """Select the best hypothesis given a lattice and a reference transcript. + + The basic idea is to extract n paths from the given lattice, unique them, + and select the one that has the minimum edit distance with the corresponding + reference transcript as the decoding output. + + The decoding result returned from this function is the best result that + we can obtain using n-best decoding with all kinds of rescoring techniques. + + Args: + lattice: + An FsaVec. It can be the return value of :func:`get_lattice`. + Note: We assume its aux_labels contain word IDs. + num_paths: + The size of `n` in n-best. + ref_texts: + A list of reference transcript. Each entry contains space(s) + separated words + lexicon: + It is used to convert word IDs to word symbols. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + Return: + Return a dict. Its key contains the information about the parameters + when calling this function, while its value contains the decoding output. + `len(ans_dict) == len(ref_texts)` + """ + path = _get_random_paths( + lattice=lattice, + num_paths=num_paths, + use_double_scores=True, + scale=scale, + ) + + word_seq = k2.index(lattice.aux_labels, path) + word_seq = k2.ragged.remove_values_leq(word_seq, 0) + unique_word_seq, _, _ = k2.ragged.unique_sequences( + word_seq, need_num_repeats=False, need_new2old_indexes=False + ) + unique_word_ids = k2.ragged.to_list(unique_word_seq) + assert len(unique_word_ids) == len(ref_texts) + # unique_word_ids[i] contains all hypotheses of the i-th utterance + + results = [] + for hyps, ref in zip(unique_word_ids, ref_texts): + # Note hyps is a list-of-list ints + # Each sublist contains a hypothesis + ref_words = ref.strip().split() + # CAUTION: We don't convert ref_words to ref_words_ids + # since there may exist OOV words in ref_words + best_hyp_words = None + min_error = float("inf") + for hyp_words in hyps: + hyp_words = [lexicon.word_table[i] for i in hyp_words] + this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"] + if this_error < min_error: + min_error = this_error + best_hyp_words = hyp_words + results.append(best_hyp_words) + + return {f"nbest_{num_paths}_scale_{scale}_oracle": results} + + def rescore_with_attention_decoder( lattice: k2.Fsa, num_paths: int, model: nn.Module, memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, + memory_key_padding_mask: Optional[torch.Tensor], sos_id: int, eos_id: int, + scale: float = 1.0, + ngram_lm_scale: Optional[float] = None, + attention_scale: Optional[float] = None, ) -> Dict[str, k2.Fsa]: """This function extracts n paths from the given lattice and uses an attention decoder to rescore them. The path with the highest @@ -580,6 +708,13 @@ def rescore_with_attention_decoder( The token ID for SOS. eos_id: The token ID for EOS. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + ngram_lm_scale: + Optional. It specifies the scale for n-gram LM scores. + attention_scale: + Optional. It specifies the scale for attention decoder scores. Returns: A dict of FsaVec, whose key contains a string ngram_lm_scale_attention_scale and the value is the @@ -587,7 +722,12 @@ def rescore_with_attention_decoder( """ # First, extract `num_paths` paths for each sequence. # path is a k2.RaggedInt with axes [seq][path][arc_pos] - path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + path = _get_random_paths( + lattice=lattice, + num_paths=num_paths, + use_double_scores=True, + scale=scale, + ) # word_seq is a k2.RaggedInt sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. @@ -605,7 +745,7 @@ def rescore_with_attention_decoder( # # num_repeats is also a k2.RaggedInt with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.num_elements() + # num_repeats.num_elements() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index @@ -662,11 +802,13 @@ def rescore_with_attention_decoder( path_to_seq_map_long = path_to_seq_map.to(torch.long) expanded_memory = memory.index_select(1, path_to_seq_map_long) - expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( - 0, path_to_seq_map_long - ) + if memory_key_padding_mask is not None: + expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( + 0, path_to_seq_map_long + ) + else: + expanded_memory_key_padding_mask = None - # TODO: pass the sos_token_id and eos_token_id via function arguments nll = model.decoder_nll( memory=expanded_memory, memory_key_padding_mask=expanded_memory_key_padding_mask, @@ -681,11 +823,17 @@ def rescore_with_attention_decoder( assert attention_scores.ndim == 1 assert attention_scores.numel() == num_word_seqs - ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] - ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + if ngram_lm_scale is None: + ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + else: + ngram_lm_scale_list = [ngram_lm_scale] - attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] - attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + if attention_scale is None: + attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + else: + attention_scale_list = [attention_scale] path_2axes = k2.ragged.remove_axis(path, 0)