diff --git a/README.md b/README.md
index 91c1f67a9..b49a7f04c 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,13 @@ It should print the path to `icefall`.
 At present, only LibriSpeech recipe is provided. Please
 follow [egs/librispeech/ASR/README.md][LibriSpeech] to run it.
 
+## Use Pre-trained models
+
+See [egs/librispeech/ASR/conformer_ctc/README.md](egs/librispeech/ASR/conformer_ctc/README.md)
+for how to use pre-trained models.
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
+
+
 [LibriSpeech]: egs/librispeech/ASR/README.md
 [k2-install]: https://k2.readthedocs.io/en/latest/installation/index.html#
 [k2]: https://github.com/k2-fsa/k2
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
new file mode 100644
index 000000000..159147a3e
--- /dev/null
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -0,0 +1,23 @@
+## Results
+
+### LibriSpeech BPE training results (Conformer-CTC)
+#### 2021-08-19
+(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/13
+
+TensorBoard log is available at https://tensorboard.dev/experiment/GnRzq8WWQW62dK4bklXBTg/#scalars
+
+Pretrained model is available at https://huggingface.co/pkufool/conformer_ctc
+
+The best decoding results (WERs) are listed below. We obtained them by averaging the models from epoch 15 to 34 and using the `attention-decoder` method with `num_paths` equal to 100.
+
+||test-clean|test-other|
+|--|--|--|
+|WER| 2.57% | 5.94% |
+
+To get more unique paths, we scaled `lattice.scores` by 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details). We searched over `lm_scale` and `attention_scale` for the best results; the scales that produced the WERs above are listed below.
+
+||lm_scale|attention_scale|
+|--|--|--|
+|test-clean|1.3|1.2|
+|test-other|1.2|1.1|
+
diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md
new file mode 100644
index 000000000..130d21351
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/README.md
@@ -0,0 +1,351 @@
+
+# How to use a pre-trained model to transcribe a sound file or multiple sound files
+
+(See the bottom of this document for the link to a Colab notebook.)
+
+You need to prepare 4 files:
+
+  - a model checkpoint file, e.g., epoch-20.pt
+  - HLG.pt, the decoding graph
+  - words.txt, the word symbol table
+  - a sound file, whose sampling rate has to be 16 kHz.
+    Supported formats are those supported by `torchaudio.load()`,
+    e.g., wav and flac.
+
+Also, you need to install `kaldifeat`. Please refer to
+<https://github.com/csukuangfj/kaldifeat> for installation.
+
+```bash
+./conformer_ctc/pretrained.py --help
+```
+
+displays the help information.
+
+## HLG decoding
+
+Once you have the above files ready and have `kaldifeat` installed,
+you can run:
+
+```bash
+./conformer_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --HLG /path/to/HLG.pt \
+  /path/to/your/sound.wav
+```
+
+and you will see the transcribed result.
+
+If you want to transcribe multiple files at the same time, you can use:
+
+```bash
+./conformer_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --HLG /path/to/HLG.pt \
+  /path/to/your/sound1.wav \
+  /path/to/your/sound2.wav \
+  /path/to/your/sound3.wav
+```
+
+**Note**: This is the fastest decoding method.
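+
+**Tip**: The sound files you pass to `./conformer_ctc/pretrained.py` must be sampled
+at 16 kHz. If your recording uses a different sample rate, you can resample it first.
+The following is only a minimal sketch using `torchaudio`; the file names
+`input.wav` and `input-16k.wav` are placeholders:
+
+```python
+import torchaudio
+import torchaudio.functional as F
+
+# Load the recording; torchaudio returns (waveform, sample_rate),
+# where waveform has shape (num_channels, num_samples).
+wave, sample_rate = torchaudio.load("input.wav")
+
+if sample_rate != 16000:
+    # Resample to the 16 kHz rate expected by ./conformer_ctc/pretrained.py.
+    wave = F.resample(wave, orig_freq=sample_rate, new_freq=16000)
+
+# Save a 16 kHz copy that can be passed to pretrained.py.
+torchaudio.save("input-16k.wav", wave, sample_rate=16000)
+```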
+
+## HLG decoding + LM rescoring
+
+`./conformer_ctc/pretrained.py` also supports `whole lattice LM rescoring`
+and `attention decoder rescoring`.
+
+To use whole lattice LM rescoring, you also need the following file:
+
+  - G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
+
+The command to run decoding with LM rescoring is:
+
+```bash
+./conformer_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --HLG /path/to/HLG.pt \
+  --method whole-lattice-rescoring \
+  --G data/lm/G_4_gram.pt \
+  --ngram-lm-scale 0.8 \
+  /path/to/your/sound1.wav \
+  /path/to/your/sound2.wav \
+  /path/to/your/sound3.wav
+```
+
+## HLG decoding + LM rescoring + attention decoder rescoring
+
+To use the attention decoder for rescoring, you need the following extra information:
+
+  - sos token ID
+  - eos token ID
+
+The command to run decoding with attention decoder rescoring is:
+
+```bash
+./conformer_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --HLG /path/to/HLG.pt \
+  --method attention-decoder \
+  --G data/lm/G_4_gram.pt \
+  --ngram-lm-scale 1.3 \
+  --attention-decoder-scale 1.2 \
+  --lattice-score-scale 0.5 \
+  --num-paths 100 \
+  --sos-id 1 \
+  --eos-id 1 \
+  /path/to/your/sound1.wav \
+  /path/to/your/sound2.wav \
+  /path/to/your/sound3.wav
+```
+
+# Decoding with a pre-trained model in action
+
+We have uploaded a pre-trained model to <https://huggingface.co/pkufool/conformer_ctc>.
+
+The following shows how to use the provided pre-trained model, step by step.
+
+### (1) Download the pre-trained model
+
+```bash
+sudo apt-get install git-lfs
+cd /path/to/icefall/egs/librispeech/ASR
+git lfs install
+mkdir tmp
+cd tmp
+git clone https://huggingface.co/pkufool/conformer_ctc
+```
+
+**CAUTION**: You have to install `git-lfs` to download the pre-trained model.
+
+You will find the following files:
+
+```
+tmp
+`-- conformer_ctc
+    |-- README.md
+    |-- data
+    |   |-- lang_bpe
+    |   |   |-- HLG.pt
+    |   |   |-- bpe.model
+    |   |   |-- tokens.txt
+    |   |   `-- words.txt
+    |   `-- lm
+    |       `-- G_4_gram.pt
+    |-- exp
+    |   `-- pretraind.pt
+    `-- test_wavs
+        |-- 1089-134686-0001.flac
+        |-- 1221-135766-0001.flac
+        |-- 1221-135766-0002.flac
+        `-- trans.txt
+
+6 directories, 11 files
+```
+
+**File descriptions**:
+
+  - `data/lang_bpe/HLG.pt`
+
+    It is the decoding graph.
+
+  - `data/lang_bpe/bpe.model`
+
+    It is a sentencepiece model. You can use it to reproduce our results.
+
+  - `data/lang_bpe/tokens.txt`
+
+    It contains tokens and their IDs, generated from `bpe.model`.
+    It is provided only for convenience so that you can look up the
+    SOS/EOS ID easily (see the snippet after this list).
+
+  - `data/lang_bpe/words.txt`
+
+    It contains words and their IDs.
+
+  - `data/lm/G_4_gram.pt`
+
+    It is a 4-gram LM, useful for LM rescoring.
+
+  - `exp/pretraind.pt`
+
+    It contains the pre-trained model parameters, obtained by averaging
+    checkpoints from `epoch-15.pt` to `epoch-34.pt`.
+    Note: We have removed the optimizer `state_dict` to reduce the file size.
+
+  - `test_wavs/*.flac`
+
+    It contains some test sound files from the LibriSpeech `test-clean` dataset.
+
+  - `test_wavs/trans.txt`
+
+    It contains the reference transcripts for the sound files in `test_wavs/`.
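+
+If you are not sure which values to pass to `--sos-id` and `--eos-id`, you can look
+them up in `data/lang_bpe/tokens.txt`. The snippet below is only a sketch: it assumes
+that `tokens.txt` stores one `<token> <id>` pair per line and that the BPE model uses
+a single shared token named `<sos/eos>`; adjust these assumptions to your setup.
+
+```python
+# Look up the SOS/EOS ID from tokens.txt (assumed format: "<token> <id>" per line).
+token2id = {}
+with open("tmp/conformer_ctc/data/lang_bpe/tokens.txt") as f:
+    for line in f:
+        if line.strip():
+            sym, idx = line.split()
+            token2id[sym] = int(idx)
+
+# Assumption: SOS and EOS share a single token named "<sos/eos>".
+sos_eos_id = token2id["<sos/eos>"]
+print(f"--sos-id {sos_eos_id} --eos-id {sos_eos_id}")
+```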
+ +The information of the test sound files is listed below: + +``` +$ soxi tmp/conformer_ctc/test_wavs/*.flac + +Input File : 'tmp/conformer_ctc/test_wavs/1089-134686-0001.flac' +Channels : 1 +Sample Rate : 16000 +Precision : 16-bit +Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors +File Size : 116k +Bit Rate : 140k +Sample Encoding: 16-bit FLAC + +Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0001.flac' +Channels : 1 +Sample Rate : 16000 +Precision : 16-bit +Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors +File Size : 343k +Bit Rate : 164k +Sample Encoding: 16-bit FLAC + +Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0002.flac' +Channels : 1 +Sample Rate : 16000 +Precision : 16-bit +Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors +File Size : 105k +Bit Rate : 174k +Sample Encoding: 16-bit FLAC + +Total Duration of 3 files: 00:00:28.16 +``` + +### (2) Use HLG decoding + +```bash +cd /path/to/icefall/egs/librispeech/ASR + +./conformer_ctc/pretrained.py \ + --checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \ + --words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \ + --HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \ + ./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac +``` + +The output is given below: + +``` +2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0 +2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model +2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt +2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer +2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac'] +2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started +2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding +2021-08-20 11:03:19,149 INFO [pretrained.py:339] +./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac: +AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + +./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac: +GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED +BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + +./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac: +YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + +2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done +``` + +### (3) Use HLG decoding + LM rescoring + +```bash +./conformer_ctc/pretrained.py \ + --checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \ + --words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \ + --HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \ + --method whole-lattice-rescoring \ + --G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \ + --ngram-lm-scale 0.8 \ + ./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac +``` + +The output is: + +``` +2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0 +2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model +2021-08-20 11:12:23,728 INFO [pretrained.py:238] 
Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt +2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt +2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer +2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac'] +2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started +2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring +2021-08-20 11:13:11,736 INFO [pretrained.py:339] +./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac: +AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + +./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac: +GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED +BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + +./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac: +YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + +2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done +``` + +### (4) Use HLG decoding + LM rescoring + attention decoder rescoring + +```bash +./conformer_ctc/pretrained.py \ + --checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \ + --words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \ + --HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \ + --method attention-decoder \ + --G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \ + --ngram-lm-scale 1.3 \ + --attention-decoder-scale 1.2 \ + --lattice-score-scale 0.5 \ + --num-paths 100 \ + --sos-id 1 \ + --eos-id 1 \ + ./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac +``` + +The output is: + +``` +2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0 +2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model +2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt +2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt +2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer +2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac'] +2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started +2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring +2021-08-20 11:20:05,805 INFO [pretrained.py:339] +./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac: +AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + +./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac: +GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED +BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + +./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac: +YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + +2021-08-20 11:20:05,805 
INFO [pretrained.py:341] Decoding Done +``` + +**NOTE**: We provide a colab notebook for demonstration. +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing) + +Due to limited memory provided by Colab, you have to upgrade to Colab Pro to +run `HLG decoding + LM rescoring` and `HLG decoding + LM rescoring + attention decoder rescoring`. +Otherwise, you can only run `HLG decoding` with Colab. diff --git a/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 889a0a474..c540b1ea1 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -13,14 +13,15 @@ from typing import Dict, List, Optional, Tuple import k2 import torch import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.decode import ( get_lattice, nbest_decoding, + nbest_oracle, one_best_decoding, rescore_with_attention_decoder, rescore_with_n_best_list, @@ -56,6 +57,18 @@ def get_parser(): "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", ) + + parser.add_argument( + "--lattice-score-scale", + type=float, + default=1.0, + help="The scale to be applied to `lattice.scores`." + "It's needed if you use any kinds of n-best based rescoring. " + "Currently, it is used when the decoding method is: nbest, " + "nbest-rescoring, attention-decoder, and nbest-oracle. " + "A smaller value results in more unique paths.", + ) + return parser @@ -85,10 +98,14 @@ def get_params() -> AttributeDict: # - nbest-rescoring # - whole-lattice-rescoring # - attention-decoder + # - nbest-oracle + # "method": "nbest", + # "method": "nbest-rescoring", # "method": "whole-lattice-rescoring", "method": "attention-decoder", + # "method": "nbest-oracle", # num_paths is used when method is "nbest", "nbest-rescoring", - # and attention-decoder + # attention-decoder, and nbest-oracle "num_paths": 100, } ) @@ -179,6 +196,19 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is slightly worse than that of rescored lattices. 
+ return nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + lexicon=lexicon, + scale=params.lattice_score_scale, + ) + if params.method in ["1best", "nbest"]: if params.method == "1best": best_path = one_best_decoding( @@ -190,8 +220,9 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, + scale=params.lattice_score_scale, ) - key = f"no_rescore-{params.num_paths}" + key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] @@ -212,6 +243,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, + scale=params.lattice_score_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( @@ -231,6 +263,7 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, + scale=params.lattice_score_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -284,7 +317,11 @@ def decode_dataset( results = [] num_cuts = 0 - tot_num_cuts = len(dl.dataset.cuts) + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -313,10 +350,10 @@ def decode_dataset( num_cuts += len(batch["supervisions"]["text"]) if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info( - f"batch {batch_idx}, cuts processed until now is " - f"{num_cuts}/{tot_num_cuts} " - f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + f"batch {batch_str}, cuts processed until now is {num_cuts}" ) return results @@ -376,7 +413,7 @@ def main(): params = get_params() params.update(vars(args)) - setup_logger(f"{params.exp_dir}/log/log-decode") + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") logging.info("Decoding started") logging.info(params) @@ -399,7 +436,9 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -430,7 +469,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") G = k2.Fsa.from_dict(d).to(device) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py new file mode 100755 index 000000000..c63616d28 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from conformer import Conformer +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the 
checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + (3) attention-decoder - Extract n paths from he rescored + lattice and use the transformer attention decoder for + rescoring. + We call it HLG decoding + n-gram LM rescoring + attention + decoder rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or attention-decoder. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and attention-decoder. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--attention-decoder-scale", + type=float, + default=1.2, + help=""" + Used only when method is attention-decoder. + It specifies the scale for attention decoder scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--lattice-score-scale", + type=float, + default=0.5, + help=""" + Used only when method is attention-decoder. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sos-id", + type=float, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the SOS token. + """, + ) + + parser.add_argument( + "--eos-id", + type=float, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the EOS token. + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "nhead": 8, + "num_classes": 5000, + "sample_rate": 16000, + "attention_dim": 512, + "subsampling_factor": 4, + "num_decoder_layers": 6, + "vgg_frontend": False, + "is_espnet_structure": True, + "mmi_loss": False, + "use_feat_batchnorm": True, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. 
+ expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + is_espnet_structure=params.is_espnet_structure, + mmi_loss=params.mmi_loss, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info(f"Decoding started") + features = fbank(waves) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + # Note: We don't use key padding mask for attention during decoding + with torch.no_grad(): + nnet_output, memory, memory_key_padding_mask = model(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + HLG=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = 
rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + scale=params.lattice_score_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info(f"Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 645757ebc..d17ee6164 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -13,6 +13,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from lhotse.utils import fix_random_seed from torch.nn.parallel import DistributedDataParallel as DDP @@ -23,7 +24,6 @@ from transformer import Noam from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.utils import ( @@ -60,9 +60,6 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) - # TODO: add extra arguments and support DDP training. - # Currently, only single GPU training is implemented. Will add - # DDP training once single GPU training is finished. 
return parser @@ -127,7 +124,7 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp_new"), + "exp_dir": Path("conformer_ctc/exp"), "lang_dir": Path("data/lang_bpe"), "feature_dim": 80, "weight_decay": 1e-6, @@ -145,7 +142,6 @@ def get_params() -> AttributeDict: "beam_size": 10, "reduction": "sum", "use_double_scores": True, - # "accum_grad": 1, "att_rate": 0.7, "attention_dim": 512, diff --git a/icefall/dataset/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py similarity index 70% rename from icefall/dataset/asr_datamodule.py rename to egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index aae7af9ce..8d8c7a366 100644 --- a/icefall/dataset/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -1,14 +1,16 @@ import argparse import logging +from functools import lru_cache from pathlib import Path from typing import List, Union -from lhotse import Fbank, FbankConfig, load_manifest +from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, CutConcatenate, CutMix, K2SpeechRecognitionDataset, + PrecomputedFeatures, SingleCutSampler, SpecAugment, ) @@ -19,7 +21,7 @@ from icefall.dataset.datamodule import DataModule from icefall.utils import str2bool -class AsrDataModule(DataModule): +class LibriSpeechAsrDataModule(DataModule): """ DataModule for K2 ASR experiments. It assumes there is always one train and valid dataloader, @@ -47,6 +49,13 @@ class AsrDataModule(DataModule): "effective batch sizes, sampling strategies, applied data " "augmentations, etc.", ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) group.add_argument( "--feature-dir", type=Path, @@ -77,7 +86,7 @@ class AsrDataModule(DataModule): group.add_argument( "--concatenate-cuts", type=str2bool, - default=True, + default=False, help="When enabled, utterances (cuts) will be concatenated " "to minimize the amount of padding.", ) @@ -104,6 +113,29 @@ class AsrDataModule(DataModule): "extraction. Will drop existing precomputed feature manifests " "if available.", ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) def train_dataloaders(self) -> DataLoader: logging.info("About to get train cuts") @@ -138,9 +170,9 @@ class AsrDataModule(DataModule): ] train = K2SpeechRecognitionDataset( - cuts_train, cut_transforms=transforms, input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) if self.args.on_the_fly_feats: @@ -154,14 +186,13 @@ class AsrDataModule(DataModule): # to be strict (e.g. could be randomized) # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # Drop feats to be on the safe side. 
- cuts_train = cuts_train.drop_features() train = K2SpeechRecognitionDataset( - cuts=cuts_train, cut_transforms=transforms, input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=80)) ), input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) if self.args.bucketing_sampler: @@ -169,44 +200,60 @@ class AsrDataModule(DataModule): train_sampler = BucketingSampler( cuts_train, max_duration=self.args.max_duration, - shuffle=True, + shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=True, ) else: logging.info("Using SingleCutSampler.") train_sampler = SingleCutSampler( cuts_train, max_duration=self.args.max_duration, - shuffle=True, + shuffle=self.args.shuffle, ) logging.info("About to create train dataloader") + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, - num_workers=4, - persistent_workers=True, + num_workers=self.args.num_workers, + persistent_workers=False, ) + return train_dl def valid_dataloaders(self) -> DataLoader: logging.info("About to get dev cuts") cuts_valid = self.valid_cuts() + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + logging.info("About to create dev dataset") if self.args.on_the_fly_feats: - cuts_valid = cuts_valid.drop_features() validate = K2SpeechRecognitionDataset( - cuts_valid.drop_features(), + cut_transforms=transforms, input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=80)) ), + return_cuts=self.args.return_cuts, ) else: - validate = K2SpeechRecognitionDataset(cuts_valid) + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) valid_sampler = SingleCutSampler( cuts_valid, max_duration=self.args.max_duration, + shuffle=False, ) logging.info("About to create dev dataloader") valid_dl = DataLoader( @@ -214,8 +261,9 @@ class AsrDataModule(DataModule): sampler=valid_sampler, batch_size=None, num_workers=2, - persistent_workers=True, + persistent_workers=False, ) + return valid_dl def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: @@ -228,10 +276,12 @@ class AsrDataModule(DataModule): for cuts_test in cuts: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - cuts_test, input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=80)) - ), + ) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, ) sampler = SingleCutSampler( cuts_test, max_duration=self.args.max_duration @@ -246,3 +296,42 @@ class AsrDataModule(DataModule): return test_loaders else: return test_loaders[0] + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest( + self.args.feature_dir / "cuts_train-clean-100.json.gz" + ) + if self.args.full_libri: + cuts_train = ( + cuts_train + + load_manifest( + self.args.feature_dir / "cuts_train-clean-360.json.gz" + ) + + load_manifest( + self.args.feature_dir / "cuts_train-other-500.json.gz" + ) + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest( + self.args.feature_dir / "cuts_dev-clean.json.gz" + ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") + return cuts_valid + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + test_sets = ["test-clean", "test-other"] + cuts = [] + for test_set in 
test_sets: + logging.debug("About to get test cuts") + cuts.append( + load_manifest( + self.args.feature_dir / f"cuts_{test_set}.json.gz" + ) + ) + return cuts diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 137fa795c..72f39ef40 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -10,10 +10,10 @@ from typing import Dict, List, Optional, Tuple import k2 import torch import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from model import TdnnLstm from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.decode import ( get_lattice, nbest_decoding, @@ -236,7 +236,11 @@ def decode_dataset( results = [] num_cuts = 0 - tot_num_cuts = len(dl.dataset.cuts) + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -263,10 +267,10 @@ def decode_dataset( num_cuts += len(batch["supervisions"]["text"]) if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info( - f"batch {batch_idx}, cuts processed until now is " - f"{num_cuts}/{tot_num_cuts} " - f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + f"batch {batch_str}, cuts processed until now is {num_cuts}" ) return results @@ -328,7 +332,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt")) + HLG = k2.Fsa.from_dict( + torch.load("data/lang_phone/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -355,7 +361,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") G = k2.Fsa.from_dict(d).to(device) if params.method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index dbb9f64ec..4adb988a0 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 -# This is just at the very beginning ... - import argparse import logging from pathlib import Path @@ -14,16 +12,16 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn import torch.optim as optim +from asr_datamodule import LibriSpeechAsrDataModule from lhotse.utils import fix_random_seed from model import TdnnLstm from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_value_ +from torch.nn.utils import clip_grad_norm_ from torch.optim.lr_scheduler import StepLR from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon @@ -61,9 +59,6 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) - # TODO: add extra arguments and support DDP training. - # Currently, only single GPU training is implemented. Will add - # DDP training once single GPU training is finished. 
return parser @@ -406,7 +401,7 @@ def train_one_epoch( optimizer.zero_grad() loss.backward() - clip_grad_value_(model.parameters(), 5.0) + clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() loss_cpu = loss.detach().cpu().item() diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index e45df4fe4..a64ecfcf6 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -91,7 +91,7 @@ def load_checkpoint( checkpoint.pop("model") def load(name, obj): - s = checkpoint[name] + s = checkpoint.get(name, None) if obj and s: obj.load_state_dict(s) checkpoint.pop(name) diff --git a/icefall/dataset/librispeech.py b/icefall/dataset/librispeech.py deleted file mode 100644 index 5c18041ed..000000000 --- a/icefall/dataset/librispeech.py +++ /dev/null @@ -1,68 +0,0 @@ -import argparse -import logging -from functools import lru_cache -from typing import List - -from lhotse import CutSet, load_manifest - -from icefall.dataset.asr_datamodule import AsrDataModule -from icefall.utils import str2bool - - -class LibriSpeechAsrDataModule(AsrDataModule): - """ - LibriSpeech ASR data module. Can be used for 100h subset - (``--full-libri false``) or full 960h set. - The train and valid cuts for standard Libri splits are - concatenated into a single CutSet/DataLoader. - """ - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) - group = parser.add_argument_group(title="LibriSpeech specific options") - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="When enabled, use 960h LibriSpeech.", - ) - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest( - self.args.feature_dir / "cuts_train-clean-100.json.gz" - ) - if self.args.full_libri: - cuts_train = ( - cuts_train - + load_manifest( - self.args.feature_dir / "cuts_train-clean-360.json.gz" - ) - + load_manifest( - self.args.feature_dir / "cuts_train-other-500.json.gz" - ) - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest( - self.args.feature_dir / "cuts_dev-clean.json.gz" - ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") - return cuts_valid - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - test_sets = ["test-clean", "test-other"] - cuts = [] - for test_set in test_sets: - logging.debug("About to get test cuts") - cuts.append( - load_manifest( - self.args.feature_dir / f"cuts_{test_set}.json.gz" - ) - ) - return cuts diff --git a/icefall/decode.py b/icefall/decode.py index 0e9baf2e4..49d642f1c 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -2,9 +2,42 @@ import logging from typing import Dict, List, Optional, Tuple, Union import k2 +import kaldialign import torch import torch.nn as nn +from icefall.lexicon import Lexicon + + +def _get_random_paths( + lattice: k2.Fsa, + num_paths: int, + use_double_scores: bool = True, + scale: float = 1.0, +): + """ + Args: + lattice: + The decoding lattice, returned by :func:`get_lattice`. + num_paths: + It specifies the size `n` in n-best. Note: Paths are selected randomly + and those containing identical word sequences are remove dand only one + of them is kept. + use_double_scores: + True to use double precision floating point in the computation. + False to use single precision. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. 
+ Returns: + Return a k2.RaggedInt with 3 axes [seq][path][arc_pos] + """ + saved_scores = lattice.scores.clone() + lattice.scores *= scale + path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + lattice.scores = saved_scores + return path + def _intersect_device( a_fsas: k2.Fsa, @@ -129,7 +162,10 @@ def one_best_decoding( def nbest_decoding( - lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True + lattice: k2.Fsa, + num_paths: int, + use_double_scores: bool = True, + scale: float = 1.0, ) -> k2.Fsa: """It implements something like CTC prefix beam search using n-best lists. @@ -152,12 +188,18 @@ def nbest_decoding( use_double_scores: True to use double precision floating point in the computation. False to use single precision. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. Returns: An FsaVec containing linear FSAs. """ - # First, extract `num_paths` paths for each sequence. - # path is a k2.RaggedInt with axes [seq][path][arc_pos] - path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + path = _get_random_paths( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + scale=scale, + ) # word_seq is a k2.RaggedInt sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. @@ -320,7 +362,11 @@ def compute_am_and_lm_scores( def rescore_with_n_best_list( - lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float] + lattice: k2.Fsa, + G: k2.Fsa, + num_paths: int, + lm_scale_list: List[float], + scale: float = 1.0, ) -> Dict[str, k2.Fsa]: """Decode using n-best list with LM rescoring. @@ -342,6 +388,9 @@ def rescore_with_n_best_list( It is the size `n` in `n-best` list. lm_scale_list: A list containing lm_scale values. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. Returns: A dict of FsaVec, whose key is an lm_scale and the value is the best decoding path for each sequence in the lattice. @@ -356,9 +405,12 @@ def rescore_with_n_best_list( assert G.device == device assert hasattr(G, "aux_labels") is False - # First, extract `num_paths` paths for each sequence. - # path is a k2.RaggedInt with axes [seq][path][arc_pos] - path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + path = _get_random_paths( + lattice=lattice, + num_paths=num_paths, + use_double_scores=True, + scale=scale, + ) # word_seq is a k2.RaggedInt sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. @@ -376,7 +428,7 @@ def rescore_with_n_best_list( # # num_repeats is also a k2.RaggedInt with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.num_elements() + # num_repeats.num_elements() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index @@ -494,6 +546,8 @@ def rescore_with_whole_lattice( del lattice.lm_scores assert hasattr(lattice, "lm_scores") is False + assert hasattr(G_with_epsilon_loops, "lm_scores") + # Now, lattice.scores contains only am_scores # inv_lattice has word IDs as labels. 
@@ -549,14 +603,88 @@ def rescore_with_whole_lattice( return ans +def nbest_oracle( + lattice: k2.Fsa, + num_paths: int, + ref_texts: List[str], + lexicon: Lexicon, + scale: float = 1.0, +) -> Dict[str, List[List[int]]]: + """Select the best hypothesis given a lattice and a reference transcript. + + The basic idea is to extract n paths from the given lattice, unique them, + and select the one that has the minimum edit distance with the corresponding + reference transcript as the decoding output. + + The decoding result returned from this function is the best result that + we can obtain using n-best decoding with all kinds of rescoring techniques. + + Args: + lattice: + An FsaVec. It can be the return value of :func:`get_lattice`. + Note: We assume its aux_labels contain word IDs. + num_paths: + The size of `n` in n-best. + ref_texts: + A list of reference transcript. Each entry contains space(s) + separated words + lexicon: + It is used to convert word IDs to word symbols. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + Return: + Return a dict. Its key contains the information about the parameters + when calling this function, while its value contains the decoding output. + `len(ans_dict) == len(ref_texts)` + """ + path = _get_random_paths( + lattice=lattice, + num_paths=num_paths, + use_double_scores=True, + scale=scale, + ) + + word_seq = k2.index(lattice.aux_labels, path) + word_seq = k2.ragged.remove_values_leq(word_seq, 0) + unique_word_seq, _, _ = k2.ragged.unique_sequences( + word_seq, need_num_repeats=False, need_new2old_indexes=False + ) + unique_word_ids = k2.ragged.to_list(unique_word_seq) + assert len(unique_word_ids) == len(ref_texts) + # unique_word_ids[i] contains all hypotheses of the i-th utterance + + results = [] + for hyps, ref in zip(unique_word_ids, ref_texts): + # Note hyps is a list-of-list ints + # Each sublist contains a hypothesis + ref_words = ref.strip().split() + # CAUTION: We don't convert ref_words to ref_words_ids + # since there may exist OOV words in ref_words + best_hyp_words = None + min_error = float("inf") + for hyp_words in hyps: + hyp_words = [lexicon.word_table[i] for i in hyp_words] + this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"] + if this_error < min_error: + min_error = this_error + best_hyp_words = hyp_words + results.append(best_hyp_words) + + return {f"nbest_{num_paths}_scale_{scale}_oracle": results} + + def rescore_with_attention_decoder( lattice: k2.Fsa, num_paths: int, model: nn.Module, memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, + memory_key_padding_mask: Optional[torch.Tensor], sos_id: int, eos_id: int, + scale: float = 1.0, + ngram_lm_scale: Optional[float] = None, + attention_scale: Optional[float] = None, ) -> Dict[str, k2.Fsa]: """This function extracts n paths from the given lattice and uses an attention decoder to rescore them. The path with the highest @@ -580,6 +708,13 @@ def rescore_with_attention_decoder( The token ID for SOS. eos_id: The token ID for EOS. + scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + ngram_lm_scale: + Optional. It specifies the scale for n-gram LM scores. + attention_scale: + Optional. It specifies the scale for attention decoder scores. Returns: A dict of FsaVec, whose key contains a string ngram_lm_scale_attention_scale and the value is the @@ -587,7 +722,12 @@ def rescore_with_attention_decoder( """ # First, extract `num_paths` paths for each sequence. 
# path is a k2.RaggedInt with axes [seq][path][arc_pos] - path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) + path = _get_random_paths( + lattice=lattice, + num_paths=num_paths, + use_double_scores=True, + scale=scale, + ) # word_seq is a k2.RaggedInt sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. @@ -605,12 +745,12 @@ def rescore_with_attention_decoder( # # num_repeats is also a k2.RaggedInt with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.num_elements() + # num_repeats.num_elements() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. - # new2old.numel() == unique_word_seqs.tot_size(1) + # new2old.numel() == unique_word_seq.tot_size(1) unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( word_seq, need_num_repeats=True, need_new2old_indexes=True ) @@ -662,11 +802,13 @@ def rescore_with_attention_decoder( path_to_seq_map_long = path_to_seq_map.to(torch.long) expanded_memory = memory.index_select(1, path_to_seq_map_long) - expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( - 0, path_to_seq_map_long - ) + if memory_key_padding_mask is not None: + expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( + 0, path_to_seq_map_long + ) + else: + expanded_memory_key_padding_mask = None - # TODO: pass the sos_token_id and eos_token_id via function arguments nll = model.decoder_nll( memory=expanded_memory, memory_key_padding_mask=expanded_memory_key_padding_mask, @@ -681,11 +823,17 @@ def rescore_with_attention_decoder( assert attention_scores.ndim == 1 assert attention_scores.numel() == num_word_seqs - ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] - ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + if ngram_lm_scale is None: + ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + else: + ngram_lm_scale_list = [ngram_lm_scale] - attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] - attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + if attention_scale is None: + attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + else: + attention_scale_list = [attention_scale] path_2axes = k2.ragged.remove_axis(path, 0)