From eed3fc561024ab6754134032e0286c9da02d4c1b Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 25 Aug 2021 17:48:34 +0800 Subject: [PATCH 01/12] Correct some spelling mistakes (#28) * Update index.rst (AS->ASR) * Update conformer_ctc.rst (pretraind->pretrained) --- docs/source/installation/index.rst | 2 +- docs/source/recipes/librispeech/conformer_ctc.rst | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index e8dc9b461..bcef669c8 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -308,7 +308,7 @@ Data preparation $ export PYTHONPATH=/tmp/icefall:$PYTHONPATH $ cd /tmp/icefall - $ cd egs/yesno/AS + $ cd egs/yesno/ASR $ ./prepare.sh The log of running ``./prepare.sh`` is: diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index 2cb04d1ba..50f262a54 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -367,7 +367,7 @@ After downloading, you will have the following files: | `-- lm | `-- G_4_gram.pt |-- exp - | `-- pretraind.pt + | `-- pretrained.pt `-- test_wavs |-- 1089-134686-0001.flac |-- 1221-135766-0001.flac @@ -475,7 +475,7 @@ The command to run HLG decoding is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ - --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretraind.pt \ + --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \ --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \ @@ -518,7 +518,7 @@ The command to run HLG decoding + LM rescoring is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ - --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretraind.pt \ + --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \ --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \ --method whole-lattice-rescoring \ @@ -566,7 +566,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ - --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretraind.pt \ + --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \ --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \ --method attention-decoder \ From 5baa6a9f1c4ef977b10c0e3adace5d8c455c935a Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 25 Aug 2021 21:41:46 +0800 Subject: [PATCH 02/12] fix a spelling mistake (tourch->touch) (#29) --- docs/source/contributing/how-to-create-a-recipe.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/contributing/how-to-create-a-recipe.rst b/docs/source/contributing/how-to-create-a-recipe.rst index 2d53fd89f..a30fb9056 100644 --- a/docs/source/contributing/how-to-create-a-recipe.rst +++ b/docs/source/contributing/how-to-create-a-recipe.rst @@ -56,7 +56,7 @@ organize your files in the following way: $ cd egs/foo/ASR $ mkdir bar $ cd bar - $ tourch README.md model.py train.py decode.py asr_datamodule.py pretrained.py + $ touch README.md model.py train.py decode.py asr_datamodule.py pretrained.py For instance , the ``yesno`` recipe has a ``tdnn`` model and its directory structure looks like the following: From 331e5eb7ab412a2ddc6456d89ece92c9d17307ea Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 2 Sep 2021 07:12:37 +0800 Subject: [PATCH 03/12] [doc] Fix typos. (#31) --- docs/source/recipes/librispeech/conformer_ctc.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index 50f262a54..af3e59e68 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -303,7 +303,7 @@ The commonly used options are: - ``--lattice-score-scale`` - It is used to scaled down lattice scores so that we can more unique + It is used to scale down lattice scores so that there are more unique paths for rescoring. - ``--max-duration`` @@ -314,7 +314,7 @@ The commonly used options are: Pre-trained Model ----------------- -We have uploaded the pre-trained model to +We have uploaded a pre-trained model to ``_. We describe how to use the pre-trained model to transcribe a sound file or @@ -324,7 +324,7 @@ Install kaldifeat ~~~~~~~~~~~~~~~~~ `kaldifeat `_ is used to -extract features for a single sound file or multiple soundfiles +extract features for a single sound file or multiple sound files at the same time. Please refer to ``_ for installation. @@ -397,7 +397,7 @@ After downloading, you will have the following files: - ``data/lm/G_4_gram.pt`` - It is a 4-gram LM, useful for LM rescoring. + It is a 4-gram LM, used for n-gram LM rescoring. - ``exp/pretrained.pt`` @@ -556,7 +556,7 @@ Its output is: HLG decoding + LM rescoring + attention decoder rescoring ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -It uses an n-gram LM to rescore the decoding lattice, extracts +It uses an n-gram LM to rescore the decoding lattice, extracts n paths from the rescored lattice, recores the extracted paths with an attention decoder. The path with the highest score is the decoding result. From abadc714157b81770dcc9b65801d9c87663a7507 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 8 Sep 2021 14:55:30 +0800 Subject: [PATCH 04/12] Use new APIs with k2.RaggedTensor (#38) * Use new APIs with k2.RaggedTensor * Fix style issues. * Update the installation doc, saying it requires at least k2 v1.7 * Use k2 v1.7 --- .github/workflows/run-yesno-recipe.yml | 2 +- .github/workflows/test.yml | 3 +- .gitignore | 2 +- docs/source/conf.py | 1 - .../images/device-CPU_CUDA-orange.svg | 2 +- docs/source/installation/images/k2-v-1.7.svg | 1 + .../images/os-Linux_macOS-ff69b4.svg | 2 +- .../images/python-3.6_3.7_3.8_3.9-blue.svg | 2 +- ....0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg | 2 +- docs/source/installation/index.rst | 23 ++- docs/source/recipes/index.rst | 1 - .../recipes/librispeech/tdnn_lstm_ctc.rst | 18 +- egs/librispeech/ASR/RESULTS.md | 1 - egs/librispeech/ASR/conformer_ctc/decode.py | 19 ++ .../ASR/conformer_ctc/test_subsampling.py | 3 +- .../ASR/conformer_ctc/test_transformer.py | 9 +- egs/librispeech/ASR/local/compile_hlg.py | 6 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 5 +- .../ASR/tdnn_lstm_ctc/pretrained.py | 0 egs/yesno/ASR/local/compile_hlg.py | 6 +- egs/yesno/ASR/tdnn/decode.py | 1 + icefall/decode.py | 191 ++++++++++-------- icefall/lexicon.py | 11 +- icefall/utils.py | 26 ++- test/test_bpe_graph_compiler.py | 3 +- test/test_utils.py | 4 +- 26 files changed, 197 insertions(+), 147 deletions(-) create mode 100644 docs/source/installation/images/k2-v-1.7.svg mode change 100644 => 100755 egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 39a6a0e80..b4e266672 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -56,7 +56,7 @@ jobs: run: | python3 -m pip install --upgrade pip black flake8 python3 -m pip install -U pip - python3 -m pip install k2==1.4.dev20210822+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ + python3 -m pip install k2==1.7.dev20210908+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ python3 -m pip install torchaudio==0.7.2 python3 -m pip install git+https://github.com/lhotse-speech/lhotse diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9110e7db4..c853e3de1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,8 @@ jobs: os: [ubuntu-18.04, macos-10.15] python-version: [3.6, 3.7, 3.8, 3.9] torch: ["1.8.1"] - k2-version: ["1.4.dev20210822"] + k2-version: ["1.7.dev20210908"] + fail-fast: false steps: diff --git a/.gitignore b/.gitignore index 839a1c34a..e6c84ca5e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ path.sh exp exp*/ *.pt -download/ +download diff --git a/docs/source/conf.py b/docs/source/conf.py index f97f72d54..599df8b3e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,7 +16,6 @@ import sphinx_rtd_theme - # -- Project information ----------------------------------------------------- project = "icefall" diff --git a/docs/source/installation/images/device-CPU_CUDA-orange.svg b/docs/source/installation/images/device-CPU_CUDA-orange.svg index b760102e3..a023a1283 100644 --- a/docs/source/installation/images/device-CPU_CUDA-orange.svg +++ b/docs/source/installation/images/device-CPU_CUDA-orange.svg @@ -1 +1 @@ -device: CPU | CUDAdeviceCPU | CUDA \ No newline at end of file +device: CPU | CUDAdeviceCPU | CUDA diff --git a/docs/source/installation/images/k2-v-1.7.svg b/docs/source/installation/images/k2-v-1.7.svg new file mode 100644 index 000000000..8a74d0b55 --- /dev/null +++ b/docs/source/installation/images/k2-v-1.7.svg @@ -0,0 +1 @@ +k2: >= v1.7k2>= v1.7 diff --git a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg index 44c112747..178813ed4 100644 --- a/docs/source/installation/images/os-Linux_macOS-ff69b4.svg +++ b/docs/source/installation/images/os-Linux_macOS-ff69b4.svg @@ -1 +1 @@ -os: Linux | macOSosLinux | macOS \ No newline at end of file +os: Linux | macOSosLinux | macOS diff --git a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg index 676feba2c..befc1e19e 100644 --- a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg +++ b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg @@ -1 +1 @@ -python: 3.6 | 3.7 | 3.8 | 3.9python3.6 | 3.7 | 3.8 | 3.9 \ No newline at end of file +python: 3.6 | 3.7 | 3.8 | 3.9python3.6 | 3.7 | 3.8 | 3.9 diff --git a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg index d9b0efe1a..496e5a9ef 100644 --- a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg +++ b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg @@ -1 +1 @@ -torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0torch1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0 \ No newline at end of file +torch: 1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0torch1.6.0 | 1.7.0 | 1.7.1 | 1.8.0 | 1.8.1 | 1.9.0 diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index bcef669c8..c11cbd1be 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -7,6 +7,7 @@ Installation - |device| - |python_versions| - |torch_versions| +- |k2_versions| .. |os| image:: ./images/os-Linux_macOS-ff69b4.svg :alt: Supported operating systems @@ -20,7 +21,10 @@ Installation .. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg :alt: Supported PyTorch versions -icefall depends on `k2 `_ and +.. |k2_versions| image:: ./images/k2-v-1.7.svg + :alt: Supported k2 versions + +``icefall`` depends on `k2 `_ and `lhotse `_. We recommend you to install ``k2`` first, as ``k2`` is bound to @@ -32,12 +36,16 @@ installs its dependency PyTorch, which can be reused by ``lhotse``. -------------- Please refer to ``_ -to install `k2`. +to install ``k2``. + +.. CAUTION:: + + You need to install ``k2`` with a version at least **v1.7**. .. HINT:: If you have already installed PyTorch and don't want to replace it, - please install a version of k2 that is compiled against the version + please install a version of ``k2`` that is compiled against the version of PyTorch you are using. (2) Install lhotse @@ -50,10 +58,15 @@ to install ``lhotse``. Install ``lhotse`` also installs its dependency `torchaudio `_. +.. CAUTION:: + + If you have installed ``torchaudio``, please consider uninstalling it before + installing ``lhotse``. Otherwise, it may update your already installed PyTorch. + (3) Download icefall -------------------- -icefall is a collection of Python scripts, so you don't need to install it +``icefall`` is a collection of Python scripts, so you don't need to install it and we don't provide a ``setup.py`` to install it. What you need is to download it and set the environment variable ``PYTHONPATH`` @@ -367,7 +380,7 @@ Now let us run the training part: .. CAUTION:: - We use ``export CUDA_VISIBLE_DEVICES=""`` so that icefall uses CPU + We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU even if there are GPUs available. The training log is given below: diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index db34fdca5..36f8dfc39 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -15,4 +15,3 @@ We may add recipes for other tasks as well in the future. yesno librispeech - diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst index a59f34db7..64f0a6a08 100644 --- a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst @@ -209,7 +209,7 @@ After downloading, you will have the following files: |-- 1221-135766-0001.flac |-- 1221-135766-0002.flac `-- trans.txt - + 6 directories, 10 files @@ -256,14 +256,14 @@ The output is: 2021-08-24 16:57:28,098 INFO [pretrained.py:266] ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - - + + 2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done @@ -297,14 +297,14 @@ The decoding output is: 2021-08-24 16:39:54,010 INFO [pretrained.py:266] ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN - + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - - + + 2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index dfc412672..d4acf9206 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -43,4 +43,3 @@ We searched the lm_score_scale for best results, the scales that produced the WE |--|--| |test-clean|0.8| |test-other|0.9| - diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index ff6374d73..cfdcff756 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -45,6 +45,7 @@ from icefall.utils import ( get_texts, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -116,6 +117,17 @@ def get_parser(): """, ) + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) + return parser @@ -541,6 +553,13 @@ def main(): logging.info(f"averaging {filenames}") model.load_state_dict(average_checkpoints(filenames)) + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) diff --git a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py index e3361d0c9..81fa234dd 100755 --- a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py @@ -16,9 +16,8 @@ # limitations under the License. -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py index b90215274..667057c51 100644 --- a/egs/librispeech/ASR/conformer_ctc/test_transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py @@ -17,17 +17,16 @@ import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 19a1ddd23..407fb7d88 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -102,14 +102,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.labels[LG.labels >= first_token_disambig_id] = 0 - assert isinstance(LG.aux_labels, k2.RaggedInt) - LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") LG = k2.connect(LG) - LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) logging.info("Arc sorting LG") LG = k2.arc_sort(LG) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index afdebd12b..87e9cddb4 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -99,8 +99,10 @@ def get_params() -> AttributeDict: # - nbest-rescoring # - whole-lattice-rescoring "method": "whole-lattice-rescoring", + # "method": "1best", + # "method": "nbest", # num_paths is used when method is "nbest" and "nbest-rescoring" - "num_paths": 30, + "num_paths": 100, } ) return params @@ -424,6 +426,7 @@ def main(): torch.save( {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" ) + return model.to(device) model.eval() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py old mode 100644 new mode 100755 diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index f2fafd013..41a927455 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -80,14 +80,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.labels[LG.labels >= first_token_disambig_id] = 0 - assert isinstance(LG.aux_labels, k2.RaggedInt) - LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0 + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") LG = k2.connect(LG) - LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) logging.info("Arc sorting LG") LG = k2.arc_sort(LG) diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index aa7b07b98..54fdbb3cc 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -296,6 +296,7 @@ def main(): torch.save( {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" ) + return model.to(device) model.eval() diff --git a/icefall/decode.py b/icefall/decode.py index de3219401..3f6e5fc84 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -84,8 +84,8 @@ def _intersect_device( for start, end in splits: indexes = torch.arange(start, end).to(b_to_a_map) - fsas = k2.index(b_fsas, indexes) - b_to_a = k2.index(b_to_a_map, indexes) + fsas = k2.index_fsa(b_fsas, indexes) + b_to_a = k2.index_select(b_to_a_map, indexes) path_lattice = k2.intersect_device( a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a ) @@ -215,18 +215,16 @@ def nbest_decoding( scale=scale, ) - # word_seq is a k2.RaggedInt sharing the same shape as `path` + # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - word_seq = k2.index(lattice.aux_labels, path) - # Note: the above operation supports also the case when - # lattice.aux_labels is a ragged tensor. In that case, - # `remove_axis=True` is used inside the pybind11 binding code, - # so the resulting `word_seq` still has 3 axes, like `path`. - # The 3 axes are [seq][path][word_id] + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) # Remove 0 (epsilon) and -1 from word_seq - word_seq = k2.ragged.remove_values_leq(word_seq, 0) + word_seq = word_seq.remove_values_leq(0) # Remove sequences with identical word sequences. # @@ -234,12 +232,12 @@ def nbest_decoding( # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, _, new2old = k2.ragged.unique_sequences( - word_seq, need_num_repeats=False, need_new2old_indexes=True + unique_word_seq, _, new2old = word_seq.unique( + need_num_repeats=False, need_new2old_indexes=True ) # Note: unique_word_seq still has the same axes as word_seq - seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + seq_to_path_shape = unique_word_seq.shape.get_layer(0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path belongs @@ -247,7 +245,7 @@ def nbest_decoding( # Remove the seq axis. # Now unique_word_seq has only two axes [path][word] - unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + unique_word_seq = unique_word_seq.remove_axis(0) # word_fsa is an FsaVec with axes [path][state][arc] word_fsa = k2.linear_fsa(unique_word_seq) @@ -275,35 +273,35 @@ def nbest_decoding( use_double_scores=use_double_scores, log_semiring=False ) - # RaggedFloat currently supports float32 only. - # If Ragged is wrapped, we can use k2.RaggedDouble here - ragged_tot_scores = k2.RaggedFloat( - seq_to_path_shape, tot_scores.to(torch.float32) - ) + ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + argmax_indexes = ragged_tot_scores.argmax() # Since we invoked `k2.ragged.unique_sequences`, which reorders # the index from `path`, we use `new2old` here to convert argmax_indexes # to the indexes into `path`. # # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index(new2old, argmax_indexes) + best_path_indexes = k2.index_select(new2old, argmax_indexes) - path_2axes = k2.ragged.remove_axis(path, 0) + path_2axes = path.remove_axis(0) - # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path = k2.index(path_2axes, best_path_indexes) + # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos] + best_path, _ = path_2axes.index( + indexes=best_path_indexes, axis=0, need_value_indexes=False + ) - # labels is a k2.RaggedInt with 2 axes [path][token_id] + # labels is a k2.RaggedTensor with 2 axes [path][token_id] # Note that it contains -1s. - labels = k2.index(lattice.labels.contiguous(), best_path) + labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - labels = k2.ragged.remove_values_eq(labels, -1) + labels = labels.remove_values_eq(-1) - # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lattice.aux_labels, best_path.values()) + # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so + # aux_labels is also a k2.RaggedTensor with 2 axes + aux_labels, _ = lattice.aux_labels.index( + indexes=best_path.data, axis=0, need_value_indexes=False + ) best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels @@ -426,33 +424,36 @@ def rescore_with_n_best_list( scale=scale, ) - # word_seq is a k2.RaggedInt sharing the same shape as `path` + # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - word_seq = k2.index(lattice.aux_labels, path) + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) # Remove epsilons and -1 from word_seq - word_seq = k2.ragged.remove_values_leq(word_seq, 0) + word_seq = word_seq.remove_values_leq(0) # Remove paths that has identical word sequences. # - # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] + # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] # except that there are no repeated paths with the same word_seq # within a sequence. # - # num_repeats is also a k2.RaggedInt with 2 axes containing the + # num_repeats is also a k2.RaggedTensor with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.tot_size(1) + # num_repeats.numel() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( - word_seq, need_num_repeats=True, need_new2old_indexes=True + unique_word_seq, num_repeats, new2old = word_seq.unique( + need_num_repeats=True, need_new2old_indexes=True ) - seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + seq_to_path_shape = unique_word_seq.shape.get_layer(0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path @@ -461,7 +462,7 @@ def rescore_with_n_best_list( # Remove the seq axis. # Now unique_word_seq has only two axes [path][word] - unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + unique_word_seq = unique_word_seq.remove_axis(0) # word_fsa is an FsaVec with axes [path][state][arc] word_fsa = k2.linear_fsa(unique_word_seq) @@ -485,39 +486,42 @@ def rescore_with_n_best_list( use_double_scores=True, log_semiring=False ) - path_2axes = k2.ragged.remove_axis(path, 0) + path_2axes = path.remove_axis(0) ans = dict() for lm_scale in lm_scale_list: tot_scores = am_scores / lm_scale + lm_scores - # Remember that we used `k2.ragged.unique_sequences` to remove repeated + # Remember that we used `k2.RaggedTensor.unique` to remove repeated # paths to avoid redundant computation in `k2.intersect_device`. # Now we use `num_repeats` to correct the scores for each path. # # NOTE(fangjun): It is commented out as it leads to a worse WER # tot_scores = tot_scores * num_repeats.values() - ragged_tot_scores = k2.RaggedFloat( - seq_to_path_shape, tot_scores.to(torch.float32) - ) - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) + argmax_indexes = ragged_tot_scores.argmax() # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index(new2old, argmax_indexes) + best_path_indexes = k2.index_select(new2old, argmax_indexes) # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path = k2.index(path_2axes, best_path_indexes) + best_path, _ = path_2axes.index( + indexes=best_path_indexes, axis=0, need_value_indexes=False + ) - # labels is a k2.RaggedInt with 2 axes [path][phone_id] + # labels is a k2.RaggedTensor with 2 axes [path][phone_id] # Note that it contains -1s. - labels = k2.index(lattice.labels.contiguous(), best_path) + labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - labels = k2.ragged.remove_values_eq(labels, -1) + labels = labels.remove_values_eq(-1) - # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lattice.aux_labels, best_path.values()) + # lattice.aux_labels is a k2.RaggedTensor tensor with 2 axes, so + # aux_labels is also a k2.RaggedTensor with 2 axes + + aux_labels, _ = lattice.aux_labels.index( + indexes=best_path.data, axis=0, need_value_indexes=False + ) best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels @@ -659,12 +663,16 @@ def nbest_oracle( scale=scale, ) - word_seq = k2.index(lattice.aux_labels, path) - word_seq = k2.ragged.remove_values_leq(word_seq, 0) - unique_word_seq, _, _ = k2.ragged.unique_sequences( - word_seq, need_num_repeats=False, need_new2old_indexes=False + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) + + word_seq = word_seq.remove_values_leq(0) + unique_word_seq, _, _ = word_seq.unique( + need_num_repeats=False, need_new2old_indexes=False ) - unique_word_ids = k2.ragged.to_list(unique_word_seq) + unique_word_ids = unique_word_seq.tolist() assert len(unique_word_ids) == len(ref_texts) # unique_word_ids[i] contains all hypotheses of the i-th utterance @@ -743,33 +751,36 @@ def rescore_with_attention_decoder( scale=scale, ) - # word_seq is a k2.RaggedInt sharing the same shape as `path` + # word_seq is a k2.RaggedTensor sharing the same shape as `path` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - word_seq = k2.index(lattice.aux_labels, path) + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path, remove_axis=True) # Remove epsilons and -1 from word_seq - word_seq = k2.ragged.remove_values_leq(word_seq, 0) + word_seq = word_seq.remove_values_leq(0) # Remove paths that has identical word sequences. # - # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word] + # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] # except that there are no repeated paths with the same word_seq # within a sequence. # - # num_repeats is also a k2.RaggedInt with 2 axes containing the + # num_repeats is also a k2.RaggedTensor with 2 axes containing the # multiplicities of each path. - # num_repeats.num_elements() == unique_word_seqs.tot_size(1) + # num_repeats.numel() == unique_word_seqs.tot_size(1) # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seq.tot_size(1) - unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( - word_seq, need_num_repeats=True, need_new2old_indexes=True + unique_word_seq, num_repeats, new2old = word_seq.unique( + need_num_repeats=True, need_new2old_indexes=True ) - seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0) + seq_to_path_shape = unique_word_seq.shape.get_layer(0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path @@ -778,7 +789,7 @@ def rescore_with_attention_decoder( # Remove the seq axis. # Now unique_word_seq has only two axes [path][word] - unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0) + unique_word_seq = unique_word_seq.remove_axis(0) # word_fsa is an FsaVec with axes [path][state][arc] word_fsa = k2.linear_fsa(unique_word_seq) @@ -796,20 +807,23 @@ def rescore_with_attention_decoder( # CAUTION: The "tokens" attribute is set in the file # local/compile_hlg.py - token_seq = k2.index(lattice.tokens, path) + if isinstance(lattice.tokens, torch.Tensor): + token_seq = k2.ragged.index(lattice.tokens, path) + else: + token_seq = lattice.tokens.index(path, remove_axis=True) # Remove epsilons and -1 from token_seq - token_seq = k2.ragged.remove_values_leq(token_seq, 0) + token_seq = token_seq.remove_values_leq(0) # Remove the seq axis. - token_seq = k2.ragged.remove_axis(token_seq, 0) + token_seq = token_seq.remove_axis(0) - token_seq, _ = k2.ragged.index( - token_seq, indexes=new2old, axis=0, need_value_indexes=False + token_seq, _ = token_seq.index( + indexes=new2old, axis=0, need_value_indexes=False ) # Now word in unique_word_seq has its corresponding token IDs. - token_ids = k2.ragged.to_list(token_seq) + token_ids = token_seq.tolist() num_word_seqs = new2old.numel() @@ -849,7 +863,7 @@ def rescore_with_attention_decoder( else: attention_scale_list = [attention_scale] - path_2axes = k2.ragged.remove_axis(path, 0) + path_2axes = path.remove_axis(0) ans = dict() for n_scale in ngram_lm_scale_list: @@ -859,23 +873,28 @@ def rescore_with_attention_decoder( + n_scale * ngram_lm_scores + a_scale * attention_scores ) - ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores) - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) + argmax_indexes = ragged_tot_scores.argmax() - best_path_indexes = k2.index(new2old, argmax_indexes) + best_path_indexes = k2.index_select(new2old, argmax_indexes) # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path = k2.index(path_2axes, best_path_indexes) + best_path, _ = path_2axes.index( + indexes=best_path_indexes, axis=0, need_value_indexes=False + ) - # labels is a k2.RaggedInt with 2 axes [path][token_id] + # labels is a k2.RaggedTensor with 2 axes [path][token_id] # Note that it contains -1s. - labels = k2.index(lattice.labels.contiguous(), best_path) + labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - labels = k2.ragged.remove_values_eq(labels, -1) + labels = labels.remove_values_eq(-1) - # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lattice.aux_labels, best_path.values()) + if isinstance(lattice.aux_labels, torch.Tensor): + aux_labels = k2.index_select(lattice.aux_labels, best_path.data) + else: + aux_labels, _ = lattice.aux_labels.index( + indexes=best_path.data, axis=0, need_value_indexes=False + ) best_path_fsa = k2.linear_fsa(labels) best_path_fsa.aux_labels = aux_labels diff --git a/icefall/lexicon.py b/icefall/lexicon.py index f1127c7cf..6730bac49 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -157,7 +157,7 @@ class BpeLexicon(Lexicon): lang_dir / "lexicon.txt" ) - def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedInt: + def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor: """Read a BPE lexicon from file and convert it to a k2 ragged tensor. @@ -200,19 +200,18 @@ class BpeLexicon(Lexicon): ) values = torch.tensor(token_ids, dtype=torch.int32) - return k2.RaggedInt(shape, values) + return k2.RaggedTensor(shape, values) - def words_to_piece_ids(self, words: List[str]) -> k2.RaggedInt: + def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor: """Convert a list of words to a ragged tensor contained word piece IDs. """ word_ids = [self.word_table[w] for w in words] word_ids = torch.tensor(word_ids, dtype=torch.int32) - ragged, _ = k2.ragged.index( - self.ragged_lexicon, + ragged, _ = self.ragged_lexicon.index( indexes=word_ids, - need_value_indexes=False, axis=0, + need_value_indexes=False, ) return ragged diff --git a/icefall/utils.py b/icefall/utils.py index 2994c2d47..1130d8947 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -26,7 +26,6 @@ from pathlib import Path from typing import Dict, Iterable, List, TextIO, Tuple, Union import k2 -import k2.ragged as k2r import kaldialign import torch import torch.distributed as dist @@ -199,26 +198,25 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: Returns a list of lists of int, containing the label sequences we decoded. """ - if isinstance(best_paths.aux_labels, k2.RaggedInt): + if isinstance(best_paths.aux_labels, k2.RaggedTensor): # remove 0's and -1's. - aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0) - aux_shape = k2r.compose_ragged_shapes( - best_paths.arcs.shape(), aux_labels.shape() - ) + aux_labels = best_paths.aux_labels.remove_values_leq(0) + # TODO: change arcs.shape() to arcs.shape + aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) # remove the states and arcs axes. - aux_shape = k2r.remove_axis(aux_shape, 1) - aux_shape = k2r.remove_axis(aux_shape, 1) - aux_labels = k2.RaggedInt(aux_shape, aux_labels.values()) + aux_shape = aux_shape.remove_axis(1) + aux_shape = aux_shape.remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data) else: # remove axis corresponding to states. - aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1) - aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels) + aux_shape = best_paths.arcs.shape().remove_axis(1) + aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) # remove 0's and -1's. - aux_labels = k2r.remove_values_leq(aux_labels, 0) + aux_labels = aux_labels.remove_values_leq(0) - assert aux_labels.num_axes() == 2 - return k2r.to_list(aux_labels) + assert aux_labels.num_axes == 2 + return aux_labels.tolist() def store_transcripts( diff --git a/test/test_bpe_graph_compiler.py b/test/test_bpe_graph_compiler.py index 67d300b7d..e58c4f1c6 100755 --- a/test/test_bpe_graph_compiler.py +++ b/test/test_bpe_graph_compiler.py @@ -16,9 +16,10 @@ # limitations under the License. +from pathlib import Path + from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.lexicon import BpeLexicon -from pathlib import Path def test(): diff --git a/test/test_utils.py b/test/test_utils.py index 2dd79689f..b4c9358fd 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -60,7 +60,7 @@ def test_get_texts_ragged(): 4 """ ) - fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]") + fsa1.aux_labels = k2.RaggedTensor("[ [1 3 0 2] [] [4 0 1] [-1]]") fsa2 = k2.Fsa.from_str( """ @@ -70,7 +70,7 @@ def test_get_texts_ragged(): 3 """ ) - fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]") + fsa2.aux_labels = k2.RaggedTensor("[[3 0 5 0 8] [0 9 7 0] [-1]]") fsas = k2.Fsa.from_fsas([fsa1, fsa2]) texts = get_texts(fsas) assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]] From 7f8e3a673ae4301df92859bc8e02b1f2466bc9a1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 9 Sep 2021 13:50:31 +0800 Subject: [PATCH 05/12] Add commands for reproducing. (#40) * Add commands for reproducing. * Use --bucketing-sampler by default. --- egs/librispeech/ASR/RESULTS.md | 26 +++++++++++++++++++ .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 4 +-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index d4acf9206..d04e912bf 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -21,6 +21,32 @@ To get more unique paths, we scaled the lattice.scores with 0.5 (see https://git |test-clean|1.3|1.2| |test-other|1.2|1.1| +You can use the following commands to reproduce our results: + +```bash +git clone https://github.com/k2-fsa/icefall +cd icefall + +# It was using ef233486, you may not need to switch to it +# git checkout ef233486 + +cd egs/librispeech/ASR +./prepare.sh + +export CUDA_VISIBLE_DEVICES="0,1,2,3" +python conformer_ctc/train.py --bucketing-sampler True \ + --concatenate-cuts False \ + --max-duration 200 \ + --full-libri True \ + --world-size 4 + +python conformer_ctc/decode.py --lattice-score-scale 0.5 \ + --epoch 34 \ + --avg 20 \ + --method attention-decoder \ + --max-duration 20 \ + --num-paths 100 +``` ### LibriSpeech training results (Tdnn-Lstm) #### 2021-08-24 diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 91c1d6a96..8290e71d1 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -82,14 +82,14 @@ class LibriSpeechAsrDataModule(DataModule): group.add_argument( "--max-duration", type=int, - default=500.0, + default=200.0, help="Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, - default=False, + default=True, help="When enabled, the batches will come from buckets of " "similar duration (saves padding frames).", ) From f792b466bfde6ccfe60cc27918628c5c31843798 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 13 Sep 2021 10:49:18 +0800 Subject: [PATCH 06/12] Change default value of lattice-score-scale from 1.0 to 0.5 (#41) * Change the default value of lattice-score-scale from 1.0 to 0.5 * Fix CI. --- .github/workflows/test.yml | 14 ++++++++++++++ egs/librispeech/ASR/conformer_ctc/decode.py | 5 +++-- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 3 ++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c853e3de1..c3025d730 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,6 +53,20 @@ jobs: # icefall requirements pip install -r requirements.txt + - name: Install graphviz + if: startsWith(matrix.os, 'ubuntu') + shell: bash + run: | + python3 -m pip install -qq graphviz + sudo apt-get -qq install graphviz + + - name: Install graphviz + if: startsWith(matrix.os, 'macos') + shell: bash + run: | + python3 -m pip install -qq graphviz + brew install -q graphviz + - name: Run tests if: startsWith(matrix.os, 'ubuntu') run: | diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index cfdcff756..85161f737 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -108,7 +108,7 @@ def get_parser(): parser.add_argument( "--lattice-score-scale", type=float, - default=1.0, + default=0.5, help="""The scale to be applied to `lattice.scores`. It's needed if you use any kinds of n-best based rescoring. Used only when "method" is one of the following values: @@ -278,7 +278,8 @@ def decode_one_batch( "attention-decoder", ] - lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] if params.method == "nbest-rescoring": diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 87e9cddb4..23b2e794c 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -206,7 +206,8 @@ def decode_one_batch( assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] - lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] if params.method == "nbest-rescoring": From 24656e9749497a1599b1d7d365e877d3464e6b83 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Mon, 13 Sep 2021 18:28:57 +0800 Subject: [PATCH 07/12] Update docs and remove unnecessary arguments (#42) * Fix typo in docs * Update docs and remove unnecessary arguments * Fix code style --- .../recipes/librispeech/conformer_ctc.rst | 12 +- .../recipes/librispeech/tdnn_lstm_ctc.rst | 90 +++++- .../ASR/conformer_ctc/conformer.py | 29 +- egs/librispeech/ASR/conformer_ctc/decode.py | 12 +- .../ASR/conformer_ctc/pretrained.py | 14 +- egs/librispeech/ASR/conformer_ctc/train.py | 58 ++-- .../ASR/conformer_ctc/transformer.py | 13 +- .../ASR/tdnn_lstm_ctc/Pre-trained.md | 270 ------------------ egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 51 +++- 9 files changed, 184 insertions(+), 365 deletions(-) delete mode 100644 egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index af3e59e68..40100bc5a 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -45,7 +45,7 @@ For example, .. code-block:: bash - $ cd egs/yesno/ASR + $ cd egs/librispeech/ASR $ ./prepare.sh --stage 0 --stop-stage 0 means to run only stage 0. @@ -171,7 +171,7 @@ The following options are used quite often: Pre-configured options ~~~~~~~~~~~~~~~~~~~~~~ -There are some training options, e.g., learning rate, +There are some training options, e.g., weight decay, number of warmup steps, results dir, etc, that are not passed from the commandline. They are pre-configured by the function ``get_params()`` in @@ -346,6 +346,10 @@ The following commands describe how to download the pre-trained model: You have to use ``git lfs`` to download the pre-trained model. +.. CAUTION:: + + In order to use this pre-trained model, your k2 version has to be v1.7 or later. + After downloading, you will have the following files: .. code-block:: bash @@ -409,9 +413,9 @@ After downloading, you will have the following files: It contains some test sound files from LibriSpeech ``test-clean`` dataset. - - `test_waves/trans.txt` + - ``test_waves/trans.txt`` - It contains the reference transcripts for the sound files in `test_waves/`. + It contains the reference transcripts for the sound files in ``test_waves/``. The information of the test sound files is listed below: diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst index 64f0a6a08..848026802 100644 --- a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst @@ -153,10 +153,6 @@ Some commonly used options are: will save the averaged model to ``tdnn_lstm_ctc/exp/pretrained.pt``. See :ref:`tdnn_lstm_ctc use a pre-trained model` for how to use it. -.. HINT:: - - There are several decoding methods provided in `tdnn_lstm_ctc/decode.py `_, you can change the decoding method by modifying ``method`` parameter in function ``get_params()``. - .. _tdnn_lstm_ctc use a pre-trained model: @@ -168,6 +164,16 @@ We have uploaded the pre-trained model to The following shows you how to use the pre-trained model. + +Install kaldifeat +~~~~~~~~~~~~~~~~~ + +`kaldifeat `_ is used to +extract features for a single sound file or multiple sound files +at the same time. + +Please refer to ``_ for installation. + Download the pre-trained model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -183,6 +189,10 @@ Download the pre-trained model You have to use ``git lfs`` to download the pre-trained model. +.. CAUTION:: + + In order to use this pre-trained model, your k2 version has to be v1.7 or later. + After downloading, you will have the following files: .. code-block:: bash @@ -212,13 +222,75 @@ After downloading, you will have the following files: 6 directories, 10 files +**File descriptions**: -Download kaldifeat -~~~~~~~~~~~~~~~~~~ + - ``data/lang_phone/HLG.pt`` + + It is the decoding graph. + + - ``data/lang_phone/tokens.txt`` + + It contains tokens and their IDs. + + - ``data/lang_phone/words.txt`` + + It contains words and their IDs. + + - ``data/lm/G_4_gram.pt`` + + It is a 4-gram LM, useful for LM rescoring. + + - ``exp/pretrained.pt`` + + It contains pre-trained model parameters, obtained by averaging + checkpoints from ``epoch-14.pt`` to ``epoch-19.pt``. + Note: We have removed optimizer ``state_dict`` to reduce file size. + + - ``test_waves/*.flac`` + + It contains some test sound files from LibriSpeech ``test-clean`` dataset. + + - ``test_waves/trans.txt`` + + It contains the reference transcripts for the sound files in ``test_waves/``. + +The information of the test sound files is listed below: + +.. code-block:: bash + + $ soxi tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/*.flac + + Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac' + Channels : 1 + Sample Rate : 16000 + Precision : 16-bit + Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors + File Size : 116k + Bit Rate : 140k + Sample Encoding: 16-bit FLAC + + + Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac' + Channels : 1 + Sample Rate : 16000 + Precision : 16-bit + Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors + File Size : 343k + Bit Rate : 164k + Sample Encoding: 16-bit FLAC + + + Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac' + Channels : 1 + Sample Rate : 16000 + Precision : 16-bit + Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors + File Size : 105k + Bit Rate : 174k + Sample Encoding: 16-bit FLAC + + Total Duration of 3 files: 00:00:28.16 -`kaldifeat `_ is used for extracting -features from a single or multiple sound files. Please refer to -``_ to install ``kaldifeat`` first. Inference with a pre-trained model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 08287d686..efe3570cb 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -56,8 +56,6 @@ class Conformer(Transformer): cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, - is_espnet_structure: bool = False, - mmi_loss: bool = True, use_feat_batchnorm: bool = False, ) -> None: super(Conformer, self).__init__( @@ -72,7 +70,6 @@ class Conformer(Transformer): dropout=dropout, normalize_before=normalize_before, vgg_frontend=vgg_frontend, - mmi_loss=mmi_loss, use_feat_batchnorm=use_feat_batchnorm, ) @@ -85,12 +82,10 @@ class Conformer(Transformer): dropout, cnn_module_kernel, normalize_before, - is_espnet_structure, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.normalize_before = normalize_before - self.is_espnet_structure = is_espnet_structure - if self.normalize_before and self.is_espnet_structure: + if self.normalize_before: self.after_norm = nn.LayerNorm(d_model) else: # Note: TorchScript detects that self.after_norm could be used inside forward() @@ -125,7 +120,7 @@ class Conformer(Transformer): mask = mask.to(x.device) x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) - if self.normalize_before and self.is_espnet_structure: + if self.normalize_before: x = self.after_norm(x) return x, mask @@ -159,11 +154,10 @@ class ConformerEncoderLayer(nn.Module): dropout: float = 0.1, cnn_module_kernel: int = 31, normalize_before: bool = True, - is_espnet_structure: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure + d_model, nhead, dropout=0.0 ) self.feed_forward = nn.Sequential( @@ -436,7 +430,6 @@ class RelPositionMultiheadAttention(nn.Module): embed_dim: int, num_heads: int, dropout: float = 0.0, - is_espnet_structure: bool = False, ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -459,8 +452,6 @@ class RelPositionMultiheadAttention(nn.Module): self._reset_parameters() - self.is_espnet_structure = is_espnet_structure - def _reset_parameters(self) -> None: nn.init.xavier_uniform_(self.in_proj.weight) nn.init.constant_(self.in_proj.bias, 0.0) @@ -690,9 +681,6 @@ class RelPositionMultiheadAttention(nn.Module): _b = _b[_start:] v = nn.functional.linear(value, _w, _b) - if not self.is_espnet_structure: - q = q * scaling - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -785,14 +773,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - if not self.is_espnet_structure: - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) - else: - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, -1 diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 85161f737..c9d31ff6c 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -137,15 +137,15 @@ def get_params() -> AttributeDict: "exp_dir": Path("conformer_ctc/exp"), "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, "feature_dim": 80, "nhead": 8, "attention_dim": 512, - "subsampling_factor": 4, "num_decoder_layers": 6, - "vgg_frontend": False, - "is_espnet_structure": True, - "mmi_loss": False, - "use_feat_batchnorm": True, + # parameters for decoding "search_beam": 20, "output_beam": 8, "min_active_states": 30, @@ -538,8 +538,6 @@ def main(): subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=params.vgg_frontend, - is_espnet_structure=params.is_espnet_structure, - mmi_loss=params.mmi_loss, use_feat_batchnorm=params.use_feat_batchnorm, ) diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 95029fadb..913088777 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -173,17 +173,17 @@ def get_parser(): def get_params() -> AttributeDict: params = AttributeDict( { + "sample_rate": 16000, + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, "feature_dim": 80, "nhead": 8, "num_classes": 5000, - "sample_rate": 16000, "attention_dim": 512, - "subsampling_factor": 4, "num_decoder_layers": 6, - "vgg_frontend": False, - "is_espnet_structure": True, - "mmi_loss": False, - "use_feat_batchnorm": True, + # parameters for decoding "search_beam": 20, "output_beam": 8, "min_active_states": 30, @@ -241,8 +241,6 @@ def main(): subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=params.vgg_frontend, - is_espnet_structure=params.is_espnet_structure, - mmi_loss=params.mmi_loss, use_feat_batchnorm=params.use_feat_batchnorm, ) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index b0dbe72ad..298b74112 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -111,15 +112,6 @@ def get_params() -> AttributeDict: - lang_dir: It contains language related input files such as "lexicon.txt" - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - best_train_loss: Best training loss so far. It is used to select the model that has the lowest training loss. It is updated during the training. @@ -138,23 +130,40 @@ def get_params() -> AttributeDict: - log_interval: Print training loss if batch_idx % log_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - valid_interval: Run validation if batch_idx % valid_interval is 0 - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Whether to do batch normalization for the + input features. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. - beam_size: It is used in k2.ctc_loss - reduction: It is used in k2.ctc_loss - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - lr_factor: The lr_factor for Noam optimizer. + + - warm_step: The warm_step for Noam optimizer. """ params = AttributeDict( { "exp_dir": Path("conformer_ctc/exp"), "lang_dir": Path("data/lang_bpe"), - "feature_dim": 80, - "weight_decay": 1e-6, - "subsampling_factor": 4, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, @@ -163,17 +172,20 @@ def get_params() -> AttributeDict: "log_interval": 10, "reset_interval": 200, "valid_interval": 3000, - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - "accum_grad": 1, - "att_rate": 0.7, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, "attention_dim": 512, "nhead": 8, "num_decoder_layers": 6, - "is_espnet_structure": True, - "mmi_loss": False, - "use_feat_batchnorm": True, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + "att_rate": 0.7, + # parameters for Noam + "weight_decay": 1e-6, "lr_factor": 5.0, "warm_step": 80000, } @@ -646,8 +658,6 @@ def run(rank, world_size, args): subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=False, - is_espnet_structure=params.is_espnet_structure, - mmi_loss=params.mmi_loss, use_feat_batchnorm=params.use_feat_batchnorm, ) diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 191d2d612..88b10b23d 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -41,7 +41,6 @@ class Transformer(nn.Module): dropout: float = 0.1, normalize_before: bool = True, vgg_frontend: bool = False, - mmi_loss: bool = True, use_feat_batchnorm: bool = False, ) -> None: """ @@ -70,7 +69,6 @@ class Transformer(nn.Module): If True, use pre-layer norm; False to use post-layer norm. vgg_frontend: True to use vgg style frontend for subsampling. - mmi_loss: use_feat_batchnorm: True to use batchnorm for the input layer. """ @@ -122,14 +120,9 @@ class Transformer(nn.Module): ) if num_decoder_layers > 0: - if mmi_loss: - self.decoder_num_class = ( - self.num_classes + 1 - ) # +1 for the sos/eos symbol - else: - self.decoder_num_class = ( - self.num_classes - ) # bpe model already has sos/eos symbol + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol self.decoder_embed = nn.Embedding( num_embeddings=self.decoder_num_class, embedding_dim=d_model diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md b/egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md deleted file mode 100644 index 83e98b37c..000000000 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md +++ /dev/null @@ -1,270 +0,0 @@ - -# How to use a pre-trained model to transcribe a sound file or multiple sound files - -(See the bottom of this document for the link to a colab notebook.) - -You need to prepare 4 files: - - - a model checkpoint file, e.g., epoch-20.pt - - HLG.pt, the decoding graph - - words.txt, the word symbol table - - a sound file, whose sampling rate has to be 16 kHz. - Supported formats are those supported by `torchaudio.load()`, - e.g., wav and flac. - -Also, you need to install `kaldifeat`. Please refer to - for installation. - -```bash -./tdnn_lstm_ctc/pretrained.py --help -``` - -displays the help information. - -## HLG decoding - -Once you have the above files ready and have `kaldifeat` installed, -you can run: - -```bash -./tdnn_lstm_ctc/pretrained.py \ - --checkpoint /path/to/your/checkpoint.pt \ - --words-file /path/to/words.txt \ - --HLG /path/to/HLG.pt \ - /path/to/your/sound.wav -``` - -and you will see the transcribed result. - -If you want to transcribe multiple files at the same time, you can use: - -```bash -./tdnn_lstm_ctc/pretrained.py \ - --checkpoint /path/to/your/checkpoint.pt \ - --words-file /path/to/words.txt \ - --HLG /path/to/HLG.pt \ - /path/to/your/sound1.wav \ - /path/to/your/sound2.wav \ - /path/to/your/sound3.wav -``` - -**Note**: This is the fastest decoding method. - -## HLG decoding + LM rescoring - -`./tdnn_lstm_ctc/pretrained.py` also supports `whole lattice LM rescoring`. - -To use whole lattice LM rescoring, you also need the following files: - - - G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh` - -The command to run decoding with LM rescoring is: - -```bash -./tdnn_lstm_ctc/pretrained.py \ - --checkpoint /path/to/your/checkpoint.pt \ - --words-file /path/to/words.txt \ - --HLG /path/to/HLG.pt \ - --method whole-lattice-rescoring \ - --G data/lm/G_4_gram.pt \ - --ngram-lm-scale 0.8 \ - /path/to/your/sound1.wav \ - /path/to/your/sound2.wav \ - /path/to/your/sound3.wav -``` - -# Decoding with a pre-trained model in action - -We have uploaded a pre-trained model to - -The following shows the steps about the usage of the provided pre-trained model. - -### (1) Download the pre-trained model - -```bash -sudo apt-get install git-lfs -cd /path/to/icefall/egs/librispeech/ASR -git lfs install -mkdir tmp -cd tmp -git clone https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc -``` - -**CAUTION**: You have to install `git-lfs` to download the pre-trained model. - -You will find the following files: - -``` -tmp/ -`-- icefall_asr_librispeech_tdnn-lstm_ctc - |-- README.md - |-- data - | |-- lang_phone - | | |-- HLG.pt - | | |-- tokens.txt - | | `-- words.txt - | `-- lm - | `-- G_4_gram.pt - |-- exp - | `-- pretrained.pt - `-- test_wavs - |-- 1089-134686-0001.flac - |-- 1221-135766-0001.flac - |-- 1221-135766-0002.flac - `-- trans.txt - -6 directories, 10 files -``` - -**File descriptions**: - - - `data/lang_phone/HLG.pt` - - It is the decoding graph. - - - `data/lang_phone/tokens.txt` - - It contains tokens and their IDs. - - - `data/lang_phone/words.txt` - - It contains words and their IDs. - - - `data/lm/G_4_gram.pt` - - It is a 4-gram LM, useful for LM rescoring. - - - `exp/pretrained.pt` - - It contains pre-trained model parameters, obtained by averaging - checkpoints from `epoch-14.pt` to `epoch-19.pt`. - Note: We have removed optimizer `state_dict` to reduce file size. - - - `test_waves/*.flac` - - It contains some test sound files from LibriSpeech `test-clean` dataset. - - - `test_waves/trans.txt` - - It contains the reference transcripts for the sound files in `test_waves/`. - -The information of the test sound files is listed below: - -``` -$ soxi tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/*.flac - -Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac' -Channels : 1 -Sample Rate : 16000 -Precision : 16-bit -Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors -File Size : 116k -Bit Rate : 140k -Sample Encoding: 16-bit FLAC - - -Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac' -Channels : 1 -Sample Rate : 16000 -Precision : 16-bit -Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors -File Size : 343k -Bit Rate : 164k -Sample Encoding: 16-bit FLAC - - -Input File : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac' -Channels : 1 -Sample Rate : 16000 -Precision : 16-bit -Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors -File Size : 105k -Bit Rate : 174k -Sample Encoding: 16-bit FLAC - -Total Duration of 3 files: 00:00:28.16 -``` - -### (2) Use HLG decoding - -```bash -cd /path/to/icefall/egs/librispeech/ASR - -./tdnn_lstm_ctc/pretrained.py \ - --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \ - --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \ - --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \ - ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \ - ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \ - ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac -``` - -The output is given below: - -``` -2021-08-24 16:57:13,315 INFO [pretrained.py:168] device: cuda:0 -2021-08-24 16:57:13,315 INFO [pretrained.py:170] Creating model -2021-08-24 16:57:18,331 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt -2021-08-24 16:57:27,581 INFO [pretrained.py:199] Constructing Fbank computer -2021-08-24 16:57:27,584 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'] -2021-08-24 16:57:27,599 INFO [pretrained.py:215] Decoding started -2021-08-24 16:57:27,791 INFO [pretrained.py:245] Use HLG decoding -2021-08-24 16:57:28,098 INFO [pretrained.py:266] -./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: -AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS - -./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: -GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN - -./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: -YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - - -2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done -``` - -### (3) Use HLG decoding + LM rescoring - -```bash -./tdnn_lstm_ctc/pretrained.py \ - --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretraind.pt \ - --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \ - --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \ - --method whole-lattice-rescoring \ - --G ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt \ - --ngram-lm-scale 0.8 \ - ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \ - ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \ - ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac -``` - -The output is: - -``` -2021-08-24 16:39:24,725 INFO [pretrained.py:168] device: cuda:0 -2021-08-24 16:39:24,725 INFO [pretrained.py:170] Creating model -2021-08-24 16:39:29,403 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt -2021-08-24 16:39:40,631 INFO [pretrained.py:190] Loading G from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt -2021-08-24 16:39:53,098 INFO [pretrained.py:199] Constructing Fbank computer -2021-08-24 16:39:53,107 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'] -2021-08-24 16:39:53,121 INFO [pretrained.py:215] Decoding started -2021-08-24 16:39:53,443 INFO [pretrained.py:250] Use HLG decoding + LM rescoring -2021-08-24 16:39:54,010 INFO [pretrained.py:266] -./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: -AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS - -./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: -GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN - -./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: -YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - - -2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done -``` - -**NOTE**: We provide a colab notebook for demonstration. -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) - -Due to limited memory provided by Colab, you have to upgrade to Colab Pro to run `HLG decoding + LM rescoring`. -Otherwise, you can only run `HLG decoding` with Colab. diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 23b2e794c..7e5ec8c0d 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -67,6 +67,47 @@ def get_parser(): "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", ) + parser.add_argument( + "--method", + type=str, + default="whole-lattice-rescoring", + help="""Decoding method. + Supported values are: + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring + """, + ) + + parser.add_argument( + "--lattice-score-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring + A smaller value results in more unique paths. + """, + ) + parser.add_argument( "--export", type=str2bool, @@ -93,16 +134,6 @@ def get_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, - # Possible values for method: - # - 1best - # - nbest - # - nbest-rescoring - # - whole-lattice-rescoring - "method": "whole-lattice-rescoring", - # "method": "1best", - # "method": "nbest", - # num_paths is used when method is "nbest" and "nbest-rescoring" - "num_paths": 100, } ) return params From a2be2896a95b59f10ffb8b8feb7c6c592bb33474 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 14 Sep 2021 13:39:52 +0800 Subject: [PATCH 08/12] Fix the link to k2's installation doc. (#46) --- docs/source/installation/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index c11cbd1be..588ec13ec 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -35,7 +35,7 @@ installs its dependency PyTorch, which can be reused by ``lhotse``. (1) Install k2 -------------- -Please refer to ``_ +Please refer to ``_ to install ``k2``. .. CAUTION:: From 9a6e0489c8ee5a3337c029595dd3ead9bcf23c91 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 14 Sep 2021 16:39:56 +0800 Subject: [PATCH 09/12] update api for RaggedTensor (#45) * Fix code style * update k2 version in CI * fix compile hlg --- .github/workflows/run-yesno-recipe.yml | 2 +- .github/workflows/test.yml | 2 +- egs/librispeech/ASR/local/compile_hlg.py | 2 +- egs/yesno/ASR/local/compile_hlg.py | 2 +- icefall/decode.py | 25 +++++++++++++++--------- icefall/utils.py | 2 +- 6 files changed, 21 insertions(+), 14 deletions(-) diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index b4e266672..448ec3e32 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -56,7 +56,7 @@ jobs: run: | python3 -m pip install --upgrade pip black flake8 python3 -m pip install -U pip - python3 -m pip install k2==1.7.dev20210908+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ + python3 -m pip install k2==1.7.dev20210914+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ python3 -m pip install torchaudio==0.7.2 python3 -m pip install git+https://github.com/lhotse-speech/lhotse diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c3025d730..13b7742d9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: os: [ubuntu-18.04, macos-10.15] python-version: [3.6, 3.7, 3.8, 3.9] torch: ["1.8.1"] - k2-version: ["1.7.dev20210908"] + k2-version: ["1.7.dev20210914"] fail-fast: false diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 407fb7d88..098d5d6a3 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -103,7 +103,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.labels[LG.labels >= first_token_disambig_id] = 0 assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0 + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index 41a927455..9b6a4c5ba 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -81,7 +81,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.labels[LG.labels >= first_token_disambig_id] = 0 assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0 + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") diff --git a/icefall/decode.py b/icefall/decode.py index 3f6e5fc84..dfac5700e 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -221,7 +221,8 @@ def nbest_decoding( if isinstance(lattice.aux_labels, torch.Tensor): word_seq = k2.ragged.index(lattice.aux_labels, path) else: - word_seq = lattice.aux_labels.index(path, remove_axis=True) + word_seq = lattice.aux_labels.index(path) + word_seq = word_seq.remove_axis(1) # Remove 0 (epsilon) and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -300,7 +301,7 @@ def nbest_decoding( # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so # aux_labels is also a k2.RaggedTensor with 2 axes aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.data, axis=0, need_value_indexes=False + indexes=best_path.values, axis=0, need_value_indexes=False ) best_path_fsa = k2.linear_fsa(labels) @@ -430,7 +431,8 @@ def rescore_with_n_best_list( if isinstance(lattice.aux_labels, torch.Tensor): word_seq = k2.ragged.index(lattice.aux_labels, path) else: - word_seq = lattice.aux_labels.index(path, remove_axis=True) + word_seq = lattice.aux_labels.index(path) + word_seq = word_seq.remove_axis(1) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -520,7 +522,7 @@ def rescore_with_n_best_list( # aux_labels is also a k2.RaggedTensor with 2 axes aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.data, axis=0, need_value_indexes=False + indexes=best_path.values, axis=0, need_value_indexes=False ) best_path_fsa = k2.linear_fsa(labels) @@ -666,7 +668,8 @@ def nbest_oracle( if isinstance(lattice.aux_labels, torch.Tensor): word_seq = k2.ragged.index(lattice.aux_labels, path) else: - word_seq = lattice.aux_labels.index(path, remove_axis=True) + word_seq = lattice.aux_labels.index(path) + word_seq = word_seq.remove_axis(1) word_seq = word_seq.remove_values_leq(0) unique_word_seq, _, _ = word_seq.unique( @@ -757,7 +760,8 @@ def rescore_with_attention_decoder( if isinstance(lattice.aux_labels, torch.Tensor): word_seq = k2.ragged.index(lattice.aux_labels, path) else: - word_seq = lattice.aux_labels.index(path, remove_axis=True) + word_seq = lattice.aux_labels.index(path) + word_seq = word_seq.remove_axis(1) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -810,7 +814,8 @@ def rescore_with_attention_decoder( if isinstance(lattice.tokens, torch.Tensor): token_seq = k2.ragged.index(lattice.tokens, path) else: - token_seq = lattice.tokens.index(path, remove_axis=True) + token_seq = lattice.tokens.index(path) + token_seq = token_seq.remove_axis(1) # Remove epsilons and -1 from token_seq token_seq = token_seq.remove_values_leq(0) @@ -890,10 +895,12 @@ def rescore_with_attention_decoder( labels = labels.remove_values_eq(-1) if isinstance(lattice.aux_labels, torch.Tensor): - aux_labels = k2.index_select(lattice.aux_labels, best_path.data) + aux_labels = k2.index_select( + lattice.aux_labels, best_path.values + ) else: aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.data, axis=0, need_value_indexes=False + indexes=best_path.values, axis=0, need_value_indexes=False ) best_path_fsa = k2.linear_fsa(labels) diff --git a/icefall/utils.py b/icefall/utils.py index 1130d8947..cc658ae32 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -207,7 +207,7 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: # remove the states and arcs axes. aux_shape = aux_shape.remove_axis(1) aux_shape = aux_shape.remove_axis(1) - aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data) + aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values) else: # remove axis corresponding to states. aux_shape = best_paths.arcs.shape().remove_axis(1) From cc77cb3459e2a34b542d55ced6127a6b0372c14b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 17 Sep 2021 16:49:03 +0800 Subject: [PATCH 10/12] Fix decode.py to remove the correct axis. (#50) * Fix decode.py to remove the correct axis. * Run GitHub actions manually. --- .github/workflows/run-yesno-recipe.yml | 10 +++++----- .github/workflows/test.yml | 6 +++--- icefall/decode.py | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 448ec3e32..edd3d39ce 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -21,11 +21,11 @@ on: branches: - master pull_request: - branches: - - master + types: [labeled] jobs: run-yesno-recipe: + if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: matrix: @@ -33,6 +33,8 @@ jobs: # TODO: enable macOS for CPU testing os: [ubuntu-18.04] python-version: [3.8] + torch: ["1.8.1"] + k2-version: ["1.8.dev20210917"] fail-fast: false steps: @@ -54,10 +56,8 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black flake8 python3 -m pip install -U pip - python3 -m pip install k2==1.7.dev20210914+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/ - python3 -m pip install torchaudio==0.7.2 + pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ python3 -m pip install git+https://github.com/lhotse-speech/lhotse # We are in ./icefall and there is a file: requirements.txt in it diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 13b7742d9..6da27170c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,18 +21,18 @@ on: branches: - master pull_request: - branches: - - master + types: [labeled] jobs: test: + if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-18.04, macos-10.15] python-version: [3.6, 3.7, 3.8, 3.9] torch: ["1.8.1"] - k2-version: ["1.7.dev20210914"] + k2-version: ["1.8.dev20210917"] fail-fast: false diff --git a/icefall/decode.py b/icefall/decode.py index dfac5700e..29b76d973 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -222,7 +222,7 @@ def nbest_decoding( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) # Remove 0 (epsilon) and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -432,7 +432,7 @@ def rescore_with_n_best_list( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -669,7 +669,7 @@ def nbest_oracle( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) word_seq = word_seq.remove_values_leq(0) unique_word_seq, _, _ = word_seq.unique( @@ -761,7 +761,7 @@ def rescore_with_attention_decoder( word_seq = k2.ragged.index(lattice.aux_labels, path) else: word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(1) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) # Remove epsilons and -1 from word_seq word_seq = word_seq.remove_values_leq(0) @@ -815,7 +815,7 @@ def rescore_with_attention_decoder( token_seq = k2.ragged.index(lattice.tokens, path) else: token_seq = lattice.tokens.index(path) - token_seq = token_seq.remove_axis(1) + token_seq = token_seq.remove_axis(token_seq.num_axes - 2) # Remove epsilons and -1 from token_seq token_seq = token_seq.remove_values_leq(0) From a80e58e15d9d7b2639f1cc5f2a5997634344aa80 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 20 Sep 2021 15:44:54 +0800 Subject: [PATCH 11/12] Refactor decode.py to make it more readable and more modular. (#44) * Refactor decode.py to make it more readable and more modular. * Fix an error. Nbest.fsa should always have token IDs as labels and word IDs as aux_labels. * Add nbest decoding. * Compute edit distance with k2. * Refactor nbest-oracle. * Add rescore with nbest lists. * Add whole-lattice rescoring. * Add rescoring with attention decoder. * Refactoring. * Fixes after refactoring. * Fix a typo. * Minor fixes. * Replace [] with () for shapes. * Use k2 v1.9 * Use Levenshtein graphs/alignment from k2 v1.9 * [doc] Require k2 >= v1.9 * Minor fixes. --- .github/workflows/run-yesno-recipe.yml | 2 +- .github/workflows/test.yml | 2 +- docs/source/installation/images/k2-v-1.7.svg | 1 - .../images/k2-v1.9-blueviolet.svg | 1 + docs/source/installation/index.rst | 4 +- .../ASR/conformer_ctc/conformer.py | 2 +- egs/librispeech/ASR/conformer_ctc/decode.py | 43 +- .../ASR/conformer_ctc/pretrained.py | 2 +- .../ASR/conformer_ctc/subsampling.py | 32 +- egs/librispeech/ASR/conformer_ctc/train.py | 4 +- .../ASR/conformer_ctc/transformer.py | 44 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 12 +- .../ASR/tdnn_lstm_ctc/pretrained.py | 4 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 6 +- egs/yesno/ASR/tdnn/decode.py | 4 +- egs/yesno/ASR/tdnn/train.py | 4 +- icefall/decode.py | 1065 ++++++++--------- icefall/graph_compiler.py | 2 +- icefall/utils.py | 12 +- test/test_decode.py | 62 + 20 files changed, 688 insertions(+), 620 deletions(-) delete mode 100644 docs/source/installation/images/k2-v-1.7.svg create mode 100644 docs/source/installation/images/k2-v1.9-blueviolet.svg create mode 100644 test/test_decode.py diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index edd3d39ce..876b95e71 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -34,7 +34,7 @@ jobs: os: [ubuntu-18.04] python-version: [3.8] torch: ["1.8.1"] - k2-version: ["1.8.dev20210917"] + k2-version: ["1.9.dev20210919"] fail-fast: false steps: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6da27170c..150b5258a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: os: [ubuntu-18.04, macos-10.15] python-version: [3.6, 3.7, 3.8, 3.9] torch: ["1.8.1"] - k2-version: ["1.8.dev20210917"] + k2-version: ["1.9.dev20210919"] fail-fast: false diff --git a/docs/source/installation/images/k2-v-1.7.svg b/docs/source/installation/images/k2-v-1.7.svg deleted file mode 100644 index 8a74d0b55..000000000 --- a/docs/source/installation/images/k2-v-1.7.svg +++ /dev/null @@ -1 +0,0 @@ -k2: >= v1.7k2>= v1.7 diff --git a/docs/source/installation/images/k2-v1.9-blueviolet.svg b/docs/source/installation/images/k2-v1.9-blueviolet.svg new file mode 100644 index 000000000..5a207b370 --- /dev/null +++ b/docs/source/installation/images/k2-v1.9-blueviolet.svg @@ -0,0 +1 @@ +k2: v1.9k2v1.9 \ No newline at end of file diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index 588ec13ec..f960033e8 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -21,7 +21,7 @@ Installation .. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg :alt: Supported PyTorch versions -.. |k2_versions| image:: ./images/k2-v-1.7.svg +.. |k2_versions| image:: ./images/k2-v1.9-blueviolet.svg :alt: Supported k2 versions ``icefall`` depends on `k2 `_ and @@ -40,7 +40,7 @@ to install ``k2``. .. CAUTION:: - You need to install ``k2`` with a version at least **v1.7**. + You need to install ``k2`` with a version at least **v1.9**. .. HINT:: diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index efe3570cb..b19b94db1 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -98,7 +98,7 @@ class Conformer(Transformer): """ Args: x: - The model input. Its shape is [N, T, C]. + The model input. Its shape is (N, T, C). supervisions: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index c9d31ff6c..b5b41c82e 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -213,12 +213,12 @@ def decode_one_batch( feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) supervisions = batch["supervisions"] nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) supervision_segments = torch.stack( ( @@ -244,14 +244,19 @@ def decode_one_batch( # Note: You can also pass rescored lattices to it. # We choose the HLG decoded lattice for speed reasons # as HLG decoding is faster and the oracle WER - # is slightly worse than that of rescored lattices. - return nbest_oracle( + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( lattice=lattice, num_paths=params.num_paths, ref_texts=supervisions["text"], word_table=word_table, - scale=params.lattice_score_scale, + lattice_score_scale=params.lattice_score_scale, + oov="", ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa + return {key: hyps} if params.method in ["1best", "nbest"]: if params.method == "1best": @@ -264,7 +269,7 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - scale=params.lattice_score_scale, + lattice_score_scale=params.lattice_score_scale, ) key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa @@ -288,17 +293,23 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - scale=params.lattice_score_scale, + lattice_score_scale=params.lattice_score_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, ) elif params.method == "attention-decoder": # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, ) + # TODO: pass `lattice` instead of `rescored_lattice` to + # `rescore_with_attention_decoder` best_path_dict = rescore_with_attention_decoder( lattice=rescored_lattice, @@ -308,16 +319,20 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, - scale=params.lattice_score_scale, + lattice_score_scale=params.lattice_score_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" ans = dict() - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + for lm_scale in lm_scale_list: + ans[lm_scale_str] = [[] * lattice.shape[0]] return ans diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 913088777..c924b87bb 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -336,7 +336,7 @@ def main(): memory_key_padding_mask=memory_key_padding_mask, sos_id=params.sos_id, eos_id=params.eos_id, - scale=params.lattice_score_scale, + lattice_score_scale=params.lattice_score_scale, ngram_lm_scale=params.ngram_lm_scale, attention_scale=params.attention_decoder_scale, ) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 720ed6c22..542fb0364 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -22,8 +22,8 @@ import torch.nn as nn class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length). - Convert an input of shape [N, T, idim] to an output - with shape [N, T', odim], where + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 It is based on @@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module): """ Args: idim: - Input dim. The input shape is [N, T, idim]. + Input dim. The input shape is (N, T, idim). Caution: It requires: T >=7, idim >=7 odim: - Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) """ assert idim >= 7 super().__init__() @@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module): Args: x: - Its shape is [N, T, idim]. + Its shape is (N, T, idim). Returns: - Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) """ - # On entry, x is [N, T, idim] - x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W] + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2] + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape [N, ((T-1)//2 - 1))//2, odim] + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) return x @@ -80,8 +80,8 @@ class VggSubsampling(nn.Module): This paper is not 100% explicit so I am guessing to some extent, and trying to compare with other VGG implementations. - Convert an input of shape [N, T, idim] to an output - with shape [N, T', odim], where + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 """ @@ -93,10 +93,10 @@ class VggSubsampling(nn.Module): Args: idim: - Input dim. The input shape is [N, T, idim]. + Input dim. The input shape is (N, T, idim). Caution: It requires: T >=7, idim >=7 odim: - Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) """ super().__init__() @@ -149,10 +149,10 @@ class VggSubsampling(nn.Module): Args: x: - Its shape is [N, T, idim]. + Its shape is (N, T, idim). Returns: - Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) """ x = x.unsqueeze(1) x = self.layers(x) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 298b74112..80b2d924a 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -310,14 +310,14 @@ def compute_loss( """ device = graph_compiler.device feature = batch["inputs"] - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) supervisions = batch["supervisions"] with torch.set_grad_enabled(is_training): nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) # NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 88b10b23d..68a4ff65c 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -83,8 +83,8 @@ class Transformer(nn.Module): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - # self.encoder_embed converts the input of shape [N, T, num_classes] - # to the shape [N, T//subsampling_factor, d_model]. + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_classes -> d_model @@ -162,7 +162,7 @@ class Transformer(nn.Module): """ Args: x: - The input tensor. Its shape is [N, T, C]. + The input tensor. Its shape is (N, T, C). supervision: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -171,17 +171,17 @@ class Transformer(nn.Module): Returns: Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is [N, T, C] - - Encoder output with shape [T, N, C]. It can be used as key and + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and value for the decoder. - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is [N, T]. + memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) encoder_memory, memory_key_padding_mask = self.run_encoder( x, supervision ) @@ -195,7 +195,7 @@ class Transformer(nn.Module): Args: x: - The model input. Its shape is [N, T, C]. + The model input. Its shape is (N, T, C). supervisions: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -206,8 +206,8 @@ class Transformer(nn.Module): padding mask for the decoder. Returns: Return a tuple with two tensors: - - The encoder output, with shape [T, N, C] - - encoder padding mask, with shape [N, T]. + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). The mask is None if `supervisions` is None. It is used as memory key padding mask in the decoder. """ @@ -225,11 +225,11 @@ class Transformer(nn.Module): Args: x: The output tensor from the transformer encoder. - Its shape is [T, N, C] + Its shape is (T, N, C) Returns: Return a tensor that can be used for CTC decoding. - Its shape is [N, T, C] + Its shape is (N, T, C) """ x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -247,7 +247,7 @@ class Transformer(nn.Module): """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -312,7 +312,7 @@ class Transformer(nn.Module): """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -654,13 +654,13 @@ class PositionalEncoding(nn.Module): def extend_pe(self, x: torch.Tensor) -> None: """Extend the time t in the positional encoding if required. - The shape of `self.pe` is [1, T1, d_model]. The shape of the input x - is [N, T, d_model]. If T > T1, then we change the shape of self.pe - to [N, T, d_model]. Otherwise, nothing is done. + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. Args: x: - It is a tensor of shape [N, T, C]. + It is a tensor of shape (N, T, C). Returns: Return None. """ @@ -678,7 +678,7 @@ class PositionalEncoding(nn.Module): pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) - # Now pe is of shape [1, T, d_model], where T is x.size(1) + # Now pe is of shape (1, T, d_model), where T is x.size(1) self.pe = pe.to(device=x.device, dtype=x.dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -687,10 +687,10 @@ class PositionalEncoding(nn.Module): Args: x: - Its shape is [N, T, C] + Its shape is (N, T, C) Returns: - Return a tensor of shape [N, T, C] + Return a tensor of shape (N, T, C) """ self.extend_pe(x) x = x * self.xscale + self.pe[:, : x.size(1), :] diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 7e5ec8c0d..1e91b1008 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -190,12 +190,12 @@ def decode_one_batch( feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + feature = feature.permute(0, 2, 1) # now feature is (N, C, T) nnet_output = model(feature) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) supervisions = batch["supervisions"] @@ -229,6 +229,7 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, + lattice_score_scale=params.lattice_score_scale, ) key = f"no_rescore-{params.num_paths}" hyps = get_texts(best_path) @@ -247,10 +248,13 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, + lattice_score_scale=params.lattice_score_scale, ) else: best_path_dict = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, ) ans = dict() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index 4f82a989c..0a543d859 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -218,11 +218,11 @@ def main(): features = pad_sequence( features, batch_first=True, padding_value=math.log(1e-10) ) - features = features.permute(0, 2, 1) # now features is [N, C, T] + features = features.permute(0, 2, 1) # now features is (N, C, T) with torch.no_grad(): nnet_output = model(features) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) batch_size = nnet_output.shape[0] supervision_segments = torch.tensor( diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 4d45d197b..695ee5130 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -290,14 +290,14 @@ def compute_loss( """ device = graph_compiler.device feature = batch["inputs"] - # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + # at entry, feature is (N, T, C) + feature = feature.permute(0, 2, 1) # now feature is (N, C, T) assert feature.ndim == 3 feature = feature.to(device) with torch.set_grad_enabled(is_training): nnet_output = model(feature) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) # NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 54fdbb3cc..325acf316 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -111,10 +111,10 @@ def decode_one_batch( feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) nnet_output = model(feature) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) batch_size = nnet_output.shape[0] supervision_segments = torch.tensor( diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 39c5ef3ef..0f5506d38 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -268,13 +268,13 @@ def compute_loss( """ device = graph_compiler.device feature = batch["inputs"] - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) with torch.set_grad_enabled(is_training): nnet_output = model(feature) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) # NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by diff --git a/icefall/decode.py b/icefall/decode.py index 29b76d973..e678e4622 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -15,42 +15,12 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import k2 -import kaldialign import torch -import torch.nn as nn - -def _get_random_paths( - lattice: k2.Fsa, - num_paths: int, - use_double_scores: bool = True, - scale: float = 1.0, -): - """ - Args: - lattice: - The decoding lattice, returned by :func:`get_lattice`. - num_paths: - It specifies the size `n` in n-best. Note: Paths are selected randomly - and those containing identical word sequences are remove dand only one - of them is kept. - use_double_scores: - True to use double precision floating point in the computation. - False to use single precision. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - Returns: - Return a k2.RaggedInt with 3 axes [seq][path][arc_pos] - """ - saved_scores = lattice.scores.clone() - lattice.scores *= scale - path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True) - lattice.scores = saved_scores - return path +from icefall.utils import get_texts def _intersect_device( @@ -65,7 +35,7 @@ def _intersect_device( CUDA OOM error. The arguments and return value of this function are the same as - k2.intersect_device. + :func:`k2.intersect_device`. """ num_fsas = b_fsas.shape[0] if num_fsas <= batch_size: @@ -106,10 +76,9 @@ def get_lattice( ) -> k2.Fsa: """Get the decoding lattice from a decoding graph and neural network output. - Args: nnet_output: - It is the output of a neural model of shape `[N, T, C]`. + It is the output of a neural model of shape `(N, T, C)`. HLG: An Fsa, the decoding graph. See also `compile_HLG.py`. supervision_segments: @@ -139,10 +108,12 @@ def get_lattice( subsampling_factor: The subsampling factor of the model. Returns: - A lattice containing the decoding result. + An FsaVec containing the decoding result. It has axes [utt][state][arc]. """ dense_fsa_vec = k2.DenseFsaVec( - nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1 + nnet_output, + supervision_segments, + allow_truncate=subsampling_factor - 1, ) lattice = k2.intersect_dense_pruned( @@ -157,8 +128,304 @@ def get_lattice( return lattice +class Nbest(object): + """ + An Nbest object contains two fields: + + (1) fsa. It is an FsaVec containing a vector of **linear** FSAs. + Its axes are [path][state][arc] + (2) shape. Its type is :class:`k2.RaggedShape`. + Its axes are [utt][path] + + The field `shape` has two axes [utt][path]. `shape.dim0` contains + the number of utterances, which is also the number of rows in the + supervision_segments. `shape.tot_size(1)` contains the number + of paths, which is also the number of FSAs in `fsa`. + + Caution: + Don't be confused by the name `Nbest`. The best in the name `Nbest` + has nothing to do with `best scores`. The important part is + `N` in `Nbest`, not `best`. + """ + + def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None: + """ + Args: + fsa: + An FsaVec with axes [path][state][arc]. It is expected to contain + a list of **linear** FSAs. + shape: + A ragged shape with two axes [utt][path]. + """ + assert len(fsa.shape) == 3, f"fsa.shape: {fsa.shape}" + assert shape.num_axes == 2, f"num_axes: {shape.num_axes}" + + if fsa.shape[0] != shape.tot_size(1): + raise ValueError( + f"{fsa.shape[0]} vs {shape.tot_size(1)}\n" + "Number of FSAs in `fsa` does not match the given shape" + ) + + self.fsa = fsa + self.shape = shape + + def __str__(self): + s = "Nbest(" + s += f"Number of utterances:{self.shape.dim0}, " + s += f"Number of Paths:{self.fsa.shape[0]})" + return s + + @staticmethod + def from_lattice( + lattice: k2.Fsa, + num_paths: int, + use_double_scores: bool = True, + lattice_score_scale: float = 0.5, + ) -> "Nbest": + """Construct an Nbest object by **sampling** `num_paths` from a lattice. + + Each sampled path is a linear FSA. + + We assume `lattice.labels` contains token IDs and `lattice.aux_labels` + contains word IDs. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + num_paths: + Number of paths to **sample** from the lattice + using :func:`k2.random_paths`. + use_double_scores: + True to use double precision in :func:`k2.random_paths`. + False to use single precision. + scale: + Scale `lattice.score` before passing it to :func:`k2.random_paths`. + A smaller value leads to more unique paths at the risk of being not + to sample the path with the best score. + Returns: + Return an Nbest instance. + """ + saved_scores = lattice.scores.clone() + lattice.scores *= lattice_score_scale + # path is a ragged tensor with dtype torch.int32. + # It has three axes [utt][path][arc_pos] + path = k2.random_paths( + lattice, num_paths=num_paths, use_double_scores=use_double_scores + ) + lattice.scores = saved_scores + + # word_seq is a k2.RaggedTensor sharing the same shape as `path` + # but it contains word IDs. Note that it also contains 0s and -1s. + # The last entry in each sublist is -1. + # It axes is [utt][path][word_id] + if isinstance(lattice.aux_labels, torch.Tensor): + word_seq = k2.ragged.index(lattice.aux_labels, path) + else: + word_seq = lattice.aux_labels.index(path) + word_seq = word_seq.remove_axis(word_seq.num_axes - 2) + + # Each utterance has `num_paths` paths but some of them transduces + # to the same word sequence, so we need to remove repeated word + # sequences within an utterance. After removing repeats, each utterance + # contains different number of paths + # + # `new2old` is a 1-D torch.Tensor mapping from the output path index + # to the input path index. + _, _, new2old = word_seq.unique( + need_num_repeats=False, need_new2old_indexes=True + ) + + # kept_path is a ragged tensor with dtype torch.int32. + # It has axes [utt][path][arc_pos] + kept_path, _ = path.index(new2old, axis=1, need_value_indexes=False) + + # utt_to_path_shape has axes [utt][path] + utt_to_path_shape = kept_path.shape.get_layer(0) + + # Remove the utterance axis. + # Now kept_path has only two axes [path][arc_pos] + kept_path = kept_path.remove_axis(0) + + # labels is a ragged tensor with 2 axes [path][token_id] + # Note that it contains -1s. + labels = k2.ragged.index(lattice.labels.contiguous(), kept_path) + + # Remove -1 from labels as we will use it to construct a linear FSA + labels = labels.remove_values_eq(-1) + + if isinstance(lattice.aux_labels, k2.RaggedTensor): + # lattice.aux_labels is a ragged tensor with dtype torch.int32. + # It has 2 axes [arc][word], so aux_labels is also a ragged tensor + # with 2 axes [arc][word] + aux_labels, _ = lattice.aux_labels.index( + indexes=kept_path.values, axis=0, need_value_indexes=False + ) + else: + assert isinstance(lattice.aux_labels, torch.Tensor) + aux_labels = k2.index_select(lattice.aux_labels, kept_path.values) + # aux_labels is a 1-D torch.Tensor. It also contains -1 and 0. + + fsa = k2.linear_fsa(labels) + fsa.aux_labels = aux_labels + # Caution: fsa.scores are all 0s. + # `fsa` has only one extra attribute: aux_labels. + return Nbest(fsa=fsa, shape=utt_to_path_shape) + + def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest": + """Intersect this Nbest object with a lattice, get 1-best + path from the resulting FsaVec, and return a new Nbest object. + + The purpose of this function is to attach scores to an Nbest. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. If it has `aux_labels`, then + we assume its `labels` are token IDs and `aux_labels` are word IDs. + If it has only `labels`, we assume its `labels` are word IDs. + use_double_scores: + True to use double precision when computing shortest path. + False to use single precision. + Returns: + Return a new Nbest. This new Nbest shares the same shape with `self`, + while its `fsa` is the 1-best path from intersecting `self.fsa` and + `lattice`. Also, its `fsa` has non-zero scores and inherits attributes + for `lattice`. + """ + # Note: We view each linear FSA as a word sequence + # and we use the passed lattice to give each word sequence a score. + # + # We are not viewing each linear FSAs as a token sequence. + # + # So we use k2.invert() here. + + # We use a word fsa to intersect with k2.invert(lattice) + word_fsa = k2.invert(self.fsa) + + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + + word_fsa.scores.zero_() + word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops( + word_fsa + ) + + path_to_utt_map = self.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. + # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = _intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = _intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + + one_best = k2.shortest_path( + path_lattice, use_double_scores=use_double_scores + ) + + one_best = k2.invert(one_best) + # Now one_best has token IDs as labels and word IDs as aux_labels + + return Nbest(fsa=one_best, shape=self.shape) + + def compute_am_scores(self) -> k2.RaggedTensor: + """Compute AM scores of each linear FSA (i.e., each path within + an utterance). + + Hint: + `self.fsa.scores` contains two parts: acoustic scores (AM scores) + and n-gram language model scores (LM scores). + + Caution: + We require that ``self.fsa`` has an attribute ``lm_scores``. + + Returns: + Return a ragged tensor with 2 axes [utt][path_scores]. + Its dtype is torch.float64. + """ + saved_scores = self.fsa.scores + + # The `scores` of every arc consists of `am_scores` and `lm_scores` + self.fsa.scores = self.fsa.scores - self.fsa.lm_scores + + am_scores = self.fsa.get_tot_scores( + use_double_scores=True, log_semiring=False + ) + self.fsa.scores = saved_scores + + return k2.RaggedTensor(self.shape, am_scores) + + def compute_lm_scores(self) -> k2.RaggedTensor: + """Compute LM scores of each linear FSA (i.e., each path within + an utterance). + + Hint: + `self.fsa.scores` contains two parts: acoustic scores (AM scores) + and n-gram language model scores (LM scores). + + Caution: + We require that ``self.fsa`` has an attribute ``lm_scores``. + + Returns: + Return a ragged tensor with 2 axes [utt][path_scores]. + Its dtype is torch.float64. + """ + saved_scores = self.fsa.scores + + # The `scores` of every arc consists of `am_scores` and `lm_scores` + self.fsa.scores = self.fsa.lm_scores + + lm_scores = self.fsa.get_tot_scores( + use_double_scores=True, log_semiring=False + ) + self.fsa.scores = saved_scores + + return k2.RaggedTensor(self.shape, lm_scores) + + def tot_scores(self) -> k2.RaggedTensor: + """Get total scores of FSAs in this Nbest. + + Note: + Since FSAs in Nbest are just linear FSAs, log-semiring + and tropical semiring produce the same total scores. + + Returns: + Return a ragged tensor with two axes [utt][path_scores]. + Its dtype is torch.float64. + """ + scores = self.fsa.get_tot_scores( + use_double_scores=True, log_semiring=False + ) + return k2.RaggedTensor(self.shape, scores) + + def build_levenshtein_graphs(self) -> k2.Fsa: + """Return an FsaVec with axes [utt][state][arc].""" + word_ids = get_texts(self.fsa, return_ragged=True) + return k2.levenshtein_graph(word_ids) + + def one_best_decoding( - lattice: k2.Fsa, use_double_scores: bool = True + lattice: k2.Fsa, + use_double_scores: bool = True, ) -> k2.Fsa: """Get the best path from a lattice. @@ -179,199 +446,143 @@ def nbest_decoding( lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True, - scale: float = 1.0, + lattice_score_scale: float = 1.0, ) -> k2.Fsa: """It implements something like CTC prefix beam search using n-best lists. - The basic idea is to first extra n-best paths from the given lattice, - build a word seqs from these paths, and compute the total scores - of these sequences in the log-semiring. The one with the max score + The basic idea is to first extract `num_paths` paths from the given lattice, + build a word sequence from these paths, and compute the total scores + of the word sequence in the tropical semiring. The one with the max score is used as the decoding output. Caution: Don't be confused by `best` in the name `n-best`. Paths are selected - randomly, not by ranking their scores. + **randomly**, not by ranking their scores. + + Hint: + This decoding method is for demonstration only and it does + not produce a lower WER than :func:`one_best_decoding`. Args: lattice: - The decoding lattice, returned by :func:`get_lattice`. + The decoding lattice, e.g., can be the return value of + :func:`get_lattice`. It has 3 axes [utt][state][arc]. num_paths: It specifies the size `n` in n-best. Note: Paths are selected randomly - and those containing identical word sequences are remove dand only one + and those containing identical word sequences are removed and only one of them is kept. use_double_scores: True to use double precision floating point in the computation. False to use single precision. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. + lattice_score_scale: + It's the scale applied to the `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. Returns: - An FsaVec containing linear FSAs. + An FsaVec containing **linear** FSAs. It axes are [utt][state][arc]. """ - path = _get_random_paths( + nbest = Nbest.from_lattice( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - scale=scale, + lattice_score_scale=lattice_score_scale, ) + # nbest.fsa.scores contains 0s - # word_seq is a k2.RaggedTensor sharing the same shape as `path` - # but it contains word IDs. Note that it also contains 0s and -1s. - # The last entry in each sublist is -1. - if isinstance(lattice.aux_labels, torch.Tensor): - word_seq = k2.ragged.index(lattice.aux_labels, path) - else: - word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(word_seq.num_axes - 2) + nbest = nbest.intersect(lattice) + # now nbest.fsa.scores gets assigned - # Remove 0 (epsilon) and -1 from word_seq - word_seq = word_seq.remove_values_leq(0) + # max_indexes contains the indexes for the path with the maximum score + # within an utterance. + max_indexes = nbest.tot_scores().argmax() - # Remove sequences with identical word sequences. - # - # k2.ragged.unique_sequences will reorder paths within a seq. - # `new2old` is a 1-D torch.Tensor mapping from the output path index - # to the input path index. - # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, _, new2old = word_seq.unique( - need_num_repeats=False, need_new2old_indexes=True - ) - # Note: unique_word_seq still has the same axes as word_seq - - seq_to_path_shape = unique_word_seq.shape.get_layer(0) - - # path_to_seq_map is a 1-D torch.Tensor. - # path_to_seq_map[i] is the seq to which the i-th path belongs - path_to_seq_map = seq_to_path_shape.row_ids(1) - - # Remove the seq axis. - # Now unique_word_seq has only two axes [path][word] - unique_word_seq = unique_word_seq.remove_axis(0) - - # word_fsa is an FsaVec with axes [path][state][arc] - word_fsa = k2.linear_fsa(unique_word_seq) - - # add epsilon self loops since we will use - # k2.intersect_device, which treats epsilon as a normal symbol - word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa) - - # lattice has token IDs as labels and word IDs as aux_labels. - # inv_lattice has word IDs as labels and token IDs as aux_labels - inv_lattice = k2.invert(lattice) - inv_lattice = k2.arc_sort(inv_lattice) - - path_lattice = _intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_seq_map, - sorted_match_a=True, - ) - # path_lat has word IDs as labels and token IDs as aux_labels - - path_lattice = k2.top_sort(k2.connect(path_lattice)) - - tot_scores = path_lattice.get_tot_scores( - use_double_scores=use_double_scores, log_semiring=False - ) - - ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) - - argmax_indexes = ragged_tot_scores.argmax() - - # Since we invoked `k2.ragged.unique_sequences`, which reorders - # the index from `path`, we use `new2old` here to convert argmax_indexes - # to the indexes into `path`. - # - # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index_select(new2old, argmax_indexes) - - path_2axes = path.remove_axis(0) - - # best_path is a k2.RaggedTensor with 2 axes [path][arc_pos] - best_path, _ = path_2axes.index( - indexes=best_path_indexes, axis=0, need_value_indexes=False - ) - - # labels is a k2.RaggedTensor with 2 axes [path][token_id] - # Note that it contains -1s. - labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - - labels = labels.remove_values_eq(-1) - - # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so - # aux_labels is also a k2.RaggedTensor with 2 axes - aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.values, axis=0, need_value_indexes=False - ) - - best_path_fsa = k2.linear_fsa(labels) - best_path_fsa.aux_labels = aux_labels - return best_path_fsa + best_path = k2.index_fsa(nbest.fsa, max_indexes) + return best_path -def compute_am_and_lm_scores( +def nbest_oracle( lattice: k2.Fsa, - word_fsa_with_epsilon_loops: k2.Fsa, - path_to_seq_map: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute AM scores of n-best lists (represented as word_fsas). + num_paths: int, + ref_texts: List[str], + word_table: k2.SymbolTable, + use_double_scores: bool = True, + lattice_score_scale: float = 0.5, + oov: str = "", +) -> Dict[str, List[List[int]]]: + """Select the best hypothesis given a lattice and a reference transcript. + + The basic idea is to extract `num_paths` paths from the given lattice, + unique them, and select the one that has the minimum edit distance with + the corresponding reference transcript as the decoding output. + + The decoding result returned from this function is the best result that + we can obtain using n-best decoding with all kinds of rescoring techniques. + + This function is useful to tune the value of `lattice_score_scale`. Args: lattice: - An FsaVec, e.g., the return value of :func:`get_lattice` - It must have the attribute `lm_scores`. - word_fsa_with_epsilon_loops: - An FsaVec representing an n-best list. Note that it has been processed - by `k2.add_epsilon_self_loops`. - path_to_seq_map: - A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates - which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to. - path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0(). - Returns: - Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores). - Each tensor's `numel()' equals to `word_fsas_with_epsilon_loops.shape[0]` + An FsaVec with axes [utt][state][arc]. + Note: We assume its `aux_labels` contains word IDs. + num_paths: + The size of `n` in n-best. + ref_texts: + A list of reference transcript. Each entry contains space(s) + separated words + word_table: + It is the word symbol table. + use_double_scores: + True to use double precision for computation. False to use + single precision. + lattice_score_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + oov: + The out of vocabulary word. + Return: + Return a dict. Its key contains the information about the parameters + when calling this function, while its value contains the decoding output. + `len(ans_dict) == len(ref_texts)` """ - assert len(lattice.shape) == 3 - assert hasattr(lattice, "lm_scores") + device = lattice.device - # k2.compose() currently does not support b_to_a_map. To void - # replicating `lats`, we use k2.intersect_device here. - # - # lattice has token IDs as `labels` and word IDs as aux_labels, so we - # need to invert it here. - inv_lattice = k2.invert(lattice) - - # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor) - # and its `aux_labels` are token IDs ( a k2.RaggedInt with 2 axes) - - # Remove its `aux_labels` since it is not needed in the - # following computation - del inv_lattice.aux_labels - inv_lattice = k2.arc_sort(inv_lattice) - - path_lattice = _intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_seq_map, - sorted_match_a=True, + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + lattice_score_scale=lattice_score_scale, ) - path_lattice = k2.top_sort(k2.connect(path_lattice)) + hyps = nbest.build_levenshtein_graphs() - # The `scores` of every arc consists of `am_scores` and `lm_scores` - path_lattice.scores = path_lattice.scores - path_lattice.lm_scores + oov_id = word_table[oov] + word_ids_list = [] + for text in ref_texts: + word_ids = [] + for word in text.split(): + if word in word_table: + word_ids.append(word_table[word]) + else: + word_ids.append(oov_id) + word_ids_list.append(word_ids) - am_scores = path_lattice.get_tot_scores( - use_double_scores=True, log_semiring=False + refs = k2.levenshtein_graph(word_ids_list, device=device) + + levenshtein_alignment = k2.levenshtein_alignment( + refs=refs, + hyps=hyps, + hyp_to_ref_map=nbest.shape.row_ids(1), + sorted_match_ref=True, ) - path_lattice.scores = path_lattice.lm_scores - - lm_scores = path_lattice.get_tot_scores( - use_double_scores=True, log_semiring=False + tot_scores = levenshtein_alignment.get_tot_scores( + use_double_scores=False, log_semiring=False ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - return am_scores.to(torch.float32), lm_scores.to(torch.float32) + max_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + return best_path def rescore_with_n_best_list( @@ -379,34 +590,32 @@ def rescore_with_n_best_list( G: k2.Fsa, num_paths: int, lm_scale_list: List[float], - scale: float = 1.0, + lattice_score_scale: float = 1.0, + use_double_scores: bool = True, ) -> Dict[str, k2.Fsa]: - """Decode using n-best list with LM rescoring. - - `lattice` is a decoding lattice with 3 axes. This function first - extracts `num_paths` paths from `lattice` for each sequence using - `k2.random_paths`. The `am_scores` of these paths are computed. - For each path, its `lm_scores` is computed using `G` (which is an LM). - The final `tot_scores` is the sum of `am_scores` and `lm_scores`. - The path with the largest `tot_scores` within a sequence is used - as the decoding output. + """Rescore an n-best list with an n-gram LM. + The path with the maximum score is used as the decoding output. Args: lattice: - An FsaVec. It can be the return value of :func:`get_lattice`. + An FsaVec with axes [utt][state][arc]. It must have the following + attributes: ``aux_labels`` and ``lm_scores``. Its labels are + token IDs and ``aux_labels`` word IDs. G: - An FsaVec representing the language model (LM). Note that it - is an FsaVec, but it contains only one Fsa. + An FsaVec containing only a single FSA. It is an n-gram LM. num_paths: - It is the size `n` in `n-best` list. + Size of nbest list. lm_scale_list: - A list containing lm_scale values. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. + A list of float representing LM score scales. + lattice_score_scale: + Scale to be applied to ``lattice.score`` when sampling paths + using ``k2.random_paths``. + use_double_scores: + True to use double precision during computation. False to use + single precision. Returns: A dict of FsaVec, whose key is an lm_scale and the value is the - best decoding path for each sequence in the lattice. + best decoding path for each utterance in the lattice. """ device = lattice.device @@ -418,119 +627,32 @@ def rescore_with_n_best_list( assert G.device == device assert hasattr(G, "aux_labels") is False - path = _get_random_paths( + nbest = Nbest.from_lattice( lattice=lattice, num_paths=num_paths, - use_double_scores=True, - scale=scale, + use_double_scores=use_double_scores, + lattice_score_scale=lattice_score_scale, ) + # nbest.fsa.scores are all 0s at this point - # word_seq is a k2.RaggedTensor sharing the same shape as `path` - # but it contains word IDs. Note that it also contains 0s and -1s. - # The last entry in each sublist is -1. - if isinstance(lattice.aux_labels, torch.Tensor): - word_seq = k2.ragged.index(lattice.aux_labels, path) - else: - word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(word_seq.num_axes - 2) + nbest = nbest.intersect(lattice) + # Now nbest.fsa has its scores set + assert hasattr(nbest.fsa, "lm_scores") - # Remove epsilons and -1 from word_seq - word_seq = word_seq.remove_values_leq(0) + am_scores = nbest.compute_am_scores() - # Remove paths that has identical word sequences. - # - # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] - # except that there are no repeated paths with the same word_seq - # within a sequence. - # - # num_repeats is also a k2.RaggedTensor with 2 axes containing the - # multiplicities of each path. - # num_repeats.numel() == unique_word_seqs.tot_size(1) - # - # Since k2.ragged.unique_sequences will reorder paths within a seq, - # `new2old` is a 1-D torch.Tensor mapping from the output path index - # to the input path index. - # new2old.numel() == unique_word_seqs.tot_size(1) - unique_word_seq, num_repeats, new2old = word_seq.unique( - need_num_repeats=True, need_new2old_indexes=True - ) - - seq_to_path_shape = unique_word_seq.shape.get_layer(0) - - # path_to_seq_map is a 1-D torch.Tensor. - # path_to_seq_map[i] is the seq to which the i-th path - # belongs. - path_to_seq_map = seq_to_path_shape.row_ids(1) - - # Remove the seq axis. - # Now unique_word_seq has only two axes [path][word] - unique_word_seq = unique_word_seq.remove_axis(0) - - # word_fsa is an FsaVec with axes [path][state][arc] - word_fsa = k2.linear_fsa(unique_word_seq) - - word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa) - - am_scores, _ = compute_am_and_lm_scores( - lattice, word_fsa_with_epsilon_loops, path_to_seq_map - ) - - # Now compute lm_scores - b_to_a_map = torch.zeros_like(path_to_seq_map) - lm_path_lattice = _intersect_device( - G, - word_fsa_with_epsilon_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ) - lm_path_lattice = k2.top_sort(k2.connect(lm_path_lattice)) - lm_scores = lm_path_lattice.get_tot_scores( - use_double_scores=True, log_semiring=False - ) - - path_2axes = path.remove_axis(0) + nbest = nbest.intersect(G) + # Now nbest contains only lm scores + lm_scores = nbest.tot_scores() ans = dict() for lm_scale in lm_scale_list: - tot_scores = am_scores / lm_scale + lm_scores - - # Remember that we used `k2.RaggedTensor.unique` to remove repeated - # paths to avoid redundant computation in `k2.intersect_device`. - # Now we use `num_repeats` to correct the scores for each path. - # - # NOTE(fangjun): It is commented out as it leads to a worse WER - # tot_scores = tot_scores * num_repeats.values() - - ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) - argmax_indexes = ragged_tot_scores.argmax() - - # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index_select(new2old, argmax_indexes) - - # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path, _ = path_2axes.index( - indexes=best_path_indexes, axis=0, need_value_indexes=False - ) - - # labels is a k2.RaggedTensor with 2 axes [path][phone_id] - # Note that it contains -1s. - labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - - labels = labels.remove_values_eq(-1) - - # lattice.aux_labels is a k2.RaggedTensor tensor with 2 axes, so - # aux_labels is also a k2.RaggedTensor with 2 axes - - aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.values, axis=0, need_value_indexes=False - ) - - best_path_fsa = k2.linear_fsa(labels) - best_path_fsa.aux_labels = aux_labels - + tot_scores = am_scores.values / lm_scale + lm_scores.values + tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) key = f"lm_scale_{lm_scale}" - ans[key] = best_path_fsa - + ans[key] = best_path return ans @@ -538,25 +660,40 @@ def rescore_with_whole_lattice( lattice: k2.Fsa, G_with_epsilon_loops: k2.Fsa, lm_scale_list: Optional[List[float]] = None, + use_double_scores: bool = True, ) -> Union[k2.Fsa, Dict[str, k2.Fsa]]: - """Use whole lattice to rescore. + """Intersect the lattice with an n-gram LM and use shortest path + to decode. + + The input lattice is obtained by intersecting `HLG` with + a DenseFsaVec, where the `G` in `HLG` is in general a 3-gram LM. + The input `G_with_epsilon_loops` is usually a 4-gram LM. You can consider + this function as a second pass decoding. In the first pass decoding, we + use a small G, while we use a larger G in the second pass decoding. Args: lattice: - An FsaVec It can be the return value of :func:`get_lattice`. + An FsaVec with axes [utt][state][arc]. Its `aux_lables` are word IDs. + It must have an attribute `lm_scores`. G_with_epsilon_loops: - An FsaVec representing the language model (LM). Note that it - is an FsaVec, but it contains only one Fsa. + An FsaVec containing only a single FSA. It contains epsilon self-loops. + It is an acceptor and its labels are word IDs. lm_scale_list: - A list containing lm_scale values or None. + Optional. If none, return the intersection of `lattice` and + `G_with_epsilon_loops`. + If not None, it contains a list of values to scale LM scores. + For each scale, there is a corresponding decoding result contained in + the resulting dict. + use_double_scores: + True to use double precision in the computation. + False to use single precision. Returns: - If lm_scale_list is not None, return a dict of FsaVec, whose key - is a lm_scale and the value represents the best decoding path for - each sequence in the lattice. - If lm_scale_list is not None, return a lattice that is rescored - with the given LM. + If `lm_scale_list` is None, return a new lattice which is the intersection + result of `lattice` and `G_with_epsilon_loops`. + Otherwise, return a dict whose key is an entry in `lm_scale_list` and the + value is the decoding result (i.e., an FsaVec containing linear FSAs). """ - assert len(lattice.shape) == 3 + # Nbest is not used in this function assert hasattr(lattice, "lm_scores") assert G_with_epsilon_loops.shape == (1, None, None) @@ -564,19 +701,22 @@ def rescore_with_whole_lattice( lattice.scores = lattice.scores - lattice.lm_scores # We will use lm_scores from G, so remove lats.lm_scores here del lattice.lm_scores - assert hasattr(lattice, "lm_scores") is False assert hasattr(G_with_epsilon_loops, "lm_scores") # Now, lattice.scores contains only am_scores # inv_lattice has word IDs as labels. - # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt + # Its `aux_labels` is token IDs inv_lattice = k2.invert(lattice) num_seqs = lattice.shape[0] b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32) - while True: + + max_loop_count = 10 + loop_count = 0 + while loop_count <= max_loop_count: + loop_count += 1 try: rescoring_lattice = k2.intersect_device( G_with_epsilon_loops, @@ -592,12 +732,15 @@ def rescore_with_whole_lattice( f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}" ) - # NOTE(fangjun): The choice of the threshold 1e-7 is arbitrary here - # to avoid OOM. We may need to fine tune it. - inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-7, True) + # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here + # to avoid OOM. You may need to fine tune it. + inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True) logging.info( f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}" ) + if loop_count > max_loop_count: + logging.info("Return None as the resulting lattice is too large") + return None # lat has token IDs as labels # and word IDs as aux_labels. @@ -607,117 +750,37 @@ def rescore_with_whole_lattice( return lat ans = dict() - # - # The following implements - # scores = (scores - lm_scores)/lm_scale + lm_scores - # = scores/lm_scale + lm_scores*(1 - 1/lm_scale) - # saved_am_scores = lat.scores - lat.lm_scores for lm_scale in lm_scale_list: am_scores = saved_am_scores / lm_scale lat.scores = am_scores + lat.lm_scores - best_path = k2.shortest_path(lat, use_double_scores=True) + best_path = k2.shortest_path(lat, use_double_scores=use_double_scores) key = f"lm_scale_{lm_scale}" ans[key] = best_path return ans -def nbest_oracle( - lattice: k2.Fsa, - num_paths: int, - ref_texts: List[str], - word_table: k2.SymbolTable, - scale: float = 1.0, -) -> Dict[str, List[List[int]]]: - """Select the best hypothesis given a lattice and a reference transcript. - - The basic idea is to extract n paths from the given lattice, unique them, - and select the one that has the minimum edit distance with the corresponding - reference transcript as the decoding output. - - The decoding result returned from this function is the best result that - we can obtain using n-best decoding with all kinds of rescoring techniques. - - Args: - lattice: - An FsaVec. It can be the return value of :func:`get_lattice`. - Note: We assume its aux_labels contain word IDs. - num_paths: - The size of `n` in n-best. - ref_texts: - A list of reference transcript. Each entry contains space(s) - separated words - word_table: - It is the word symbol table. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - Return: - Return a dict. Its key contains the information about the parameters - when calling this function, while its value contains the decoding output. - `len(ans_dict) == len(ref_texts)` - """ - path = _get_random_paths( - lattice=lattice, - num_paths=num_paths, - use_double_scores=True, - scale=scale, - ) - - if isinstance(lattice.aux_labels, torch.Tensor): - word_seq = k2.ragged.index(lattice.aux_labels, path) - else: - word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(word_seq.num_axes - 2) - - word_seq = word_seq.remove_values_leq(0) - unique_word_seq, _, _ = word_seq.unique( - need_num_repeats=False, need_new2old_indexes=False - ) - unique_word_ids = unique_word_seq.tolist() - assert len(unique_word_ids) == len(ref_texts) - # unique_word_ids[i] contains all hypotheses of the i-th utterance - - results = [] - for hyps, ref in zip(unique_word_ids, ref_texts): - # Note hyps is a list-of-list ints - # Each sublist contains a hypothesis - ref_words = ref.strip().split() - # CAUTION: We don't convert ref_words to ref_words_ids - # since there may exist OOV words in ref_words - best_hyp_words = None - min_error = float("inf") - for hyp_words in hyps: - hyp_words = [word_table[i] for i in hyp_words] - this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"] - if this_error < min_error: - min_error = this_error - best_hyp_words = hyp_words - results.append(best_hyp_words) - - return {f"nbest_{num_paths}_scale_{scale}_oracle": results} - - def rescore_with_attention_decoder( lattice: k2.Fsa, num_paths: int, - model: nn.Module, + model: torch.nn.Module, memory: torch.Tensor, memory_key_padding_mask: Optional[torch.Tensor], sos_id: int, eos_id: int, - scale: float = 1.0, + lattice_score_scale: float = 1.0, ngram_lm_scale: Optional[float] = None, attention_scale: Optional[float] = None, + use_double_scores: bool = True, ) -> Dict[str, k2.Fsa]: - """This function extracts n paths from the given lattice and uses - an attention decoder to rescore them. The path with the highest - score is used as the decoding output. + """This function extracts `num_paths` paths from the given lattice and uses + an attention decoder to rescore them. The path with the highest score is + the decoding output. Args: lattice: - An FsaVec. It can be the return value of :func:`get_lattice`. + An FsaVec with axes [utt][state][arc]. num_paths: Number of paths to extract from the given lattice for rescoring. model: @@ -726,16 +789,16 @@ def rescore_with_attention_decoder( memory: The encoder memory of the given model. It is the output of the last torch.nn.TransformerEncoder layer in the given model. - Its shape is `[T, N, C]`. + Its shape is `(T, N, C)`. memory_key_padding_mask: - The padding mask for memory with shape [N, T]. + The padding mask for memory with shape `(N, T)`. sos_id: The token ID for SOS. eos_id: The token ID for EOS. - scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. + lattice_score_scale: + It's the scale applied to `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. ngram_lm_scale: Optional. It specifies the scale for n-gram LM scores. attention_scale: @@ -743,105 +806,47 @@ def rescore_with_attention_decoder( Returns: A dict of FsaVec, whose key contains a string ngram_lm_scale_attention_scale and the value is the - best decoding path for each sequence in the lattice. + best decoding path for each utterance in the lattice. """ - # First, extract `num_paths` paths for each sequence. - # path is a k2.RaggedInt with axes [seq][path][arc_pos] - path = _get_random_paths( + nbest = Nbest.from_lattice( lattice=lattice, num_paths=num_paths, - use_double_scores=True, - scale=scale, + use_double_scores=use_double_scores, + lattice_score_scale=lattice_score_scale, ) + # nbest.fsa.scores are all 0s at this point - # word_seq is a k2.RaggedTensor sharing the same shape as `path` - # but it contains word IDs. Note that it also contains 0s and -1s. - # The last entry in each sublist is -1. - if isinstance(lattice.aux_labels, torch.Tensor): - word_seq = k2.ragged.index(lattice.aux_labels, path) - else: - word_seq = lattice.aux_labels.index(path) - word_seq = word_seq.remove_axis(word_seq.num_axes - 2) + nbest = nbest.intersect(lattice) + # Now nbest.fsa has its scores set. + # Also, nbest.fsa inherits the attributes from `lattice`. + assert hasattr(nbest.fsa, "lm_scores") - # Remove epsilons and -1 from word_seq - word_seq = word_seq.remove_values_leq(0) + am_scores = nbest.compute_am_scores() + ngram_lm_scores = nbest.compute_lm_scores() - # Remove paths that has identical word sequences. - # - # unique_word_seq is still a k2.RaggedTensor with 3 axes [seq][path][word] - # except that there are no repeated paths with the same word_seq - # within a sequence. - # - # num_repeats is also a k2.RaggedTensor with 2 axes containing the - # multiplicities of each path. - # num_repeats.numel() == unique_word_seqs.tot_size(1) - # - # Since k2.ragged.unique_sequences will reorder paths within a seq, - # `new2old` is a 1-D torch.Tensor mapping from the output path index - # to the input path index. - # new2old.numel() == unique_word_seq.tot_size(1) - unique_word_seq, num_repeats, new2old = word_seq.unique( - need_num_repeats=True, need_new2old_indexes=True - ) + # The `tokens` attribute is set inside `compile_hlg.py` + assert hasattr(nbest.fsa, "tokens") + assert isinstance(nbest.fsa.tokens, torch.Tensor) - seq_to_path_shape = unique_word_seq.shape.get_layer(0) - - # path_to_seq_map is a 1-D torch.Tensor. - # path_to_seq_map[i] is the seq to which the i-th path - # belongs. - path_to_seq_map = seq_to_path_shape.row_ids(1) - - # Remove the seq axis. - # Now unique_word_seq has only two axes [path][word] - unique_word_seq = unique_word_seq.remove_axis(0) - - # word_fsa is an FsaVec with axes [path][state][arc] - word_fsa = k2.linear_fsa(unique_word_seq) - - word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa) - - am_scores, ngram_lm_scores = compute_am_and_lm_scores( - lattice, word_fsa_with_epsilon_loops, path_to_seq_map - ) - # Now we use the attention decoder to compute another - # score: attention_scores. - # - # To do that, we have to get the input and output for the attention - # decoder. - - # CAUTION: The "tokens" attribute is set in the file - # local/compile_hlg.py - if isinstance(lattice.tokens, torch.Tensor): - token_seq = k2.ragged.index(lattice.tokens, path) - else: - token_seq = lattice.tokens.index(path) - token_seq = token_seq.remove_axis(token_seq.num_axes - 2) - - # Remove epsilons and -1 from token_seq - token_seq = token_seq.remove_values_leq(0) - - # Remove the seq axis. - token_seq = token_seq.remove_axis(0) - - token_seq, _ = token_seq.index( - indexes=new2old, axis=0, need_value_indexes=False - ) - - # Now word in unique_word_seq has its corresponding token IDs. - token_ids = token_seq.tolist() - - num_word_seqs = new2old.numel() - - path_to_seq_map_long = path_to_seq_map.to(torch.long) - expanded_memory = memory.index_select(1, path_to_seq_map_long) + path_to_utt_map = nbest.shape.row_ids(1).to(torch.long) + # the shape of memory is (T, N, C), so we use axis=1 here + expanded_memory = memory.index_select(1, path_to_utt_map) if memory_key_padding_mask is not None: + # The shape of memory_key_padding_mask is (N, T), so we + # use axis=0 here. expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( - 0, path_to_seq_map_long + 0, path_to_utt_map ) else: expanded_memory_key_padding_mask = None + # remove axis corresponding to states. + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) + tokens = tokens.remove_values_leq(0) + token_ids = tokens.tolist() + nll = model.decoder_nll( memory=expanded_memory, memory_key_padding_mask=expanded_memory_key_padding_mask, @@ -850,62 +855,36 @@ def rescore_with_attention_decoder( eos_id=eos_id, ) assert nll.ndim == 2 - assert nll.shape[0] == num_word_seqs + assert nll.shape[0] == len(token_ids) attention_scores = -nll.sum(dim=1) - assert attention_scores.ndim == 1 - assert attention_scores.numel() == num_word_seqs if ngram_lm_scale is None: - ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + ngram_lm_scale_list = [0.01, 0.05, 0.08] + ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] else: ngram_lm_scale_list = [ngram_lm_scale] if attention_scale is None: - attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list = [0.01, 0.05, 0.08] + attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] else: attention_scale_list = [attention_scale] - path_2axes = path.remove_axis(0) - ans = dict() for n_scale in ngram_lm_scale_list: for a_scale in attention_scale_list: tot_scores = ( - am_scores - + n_scale * ngram_lm_scores + am_scores.values + + n_scale * ngram_lm_scores.values + a_scale * attention_scores ) - ragged_tot_scores = k2.RaggedTensor(seq_to_path_shape, tot_scores) - argmax_indexes = ragged_tot_scores.argmax() - - best_path_indexes = k2.index_select(new2old, argmax_indexes) - - # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] - best_path, _ = path_2axes.index( - indexes=best_path_indexes, axis=0, need_value_indexes=False - ) - - # labels is a k2.RaggedTensor with 2 axes [path][token_id] - # Note that it contains -1s. - labels = k2.ragged.index(lattice.labels.contiguous(), best_path) - - labels = labels.remove_values_eq(-1) - - if isinstance(lattice.aux_labels, torch.Tensor): - aux_labels = k2.index_select( - lattice.aux_labels, best_path.values - ) - else: - aux_labels, _ = lattice.aux_labels.index( - indexes=best_path.values, axis=0, need_value_indexes=False - ) - - best_path_fsa = k2.linear_fsa(labels) - best_path_fsa.aux_labels = aux_labels + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}" - ans[key] = best_path_fsa + ans[key] = best_path return ans diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index 23ac247e8..b4c87d964 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -106,7 +106,7 @@ class CtcTrainingGraphCompiler(object): word_ids_list = [] for text in texts: word_ids = [] - for word in text.split(" "): + for word in text.split(): if word in self.word_table: word_ids.append(self.word_table[word]) else: diff --git a/icefall/utils.py b/icefall/utils.py index cc658ae32..2324201c3 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -186,7 +186,9 @@ def encode_supervisions( return supervision_segments, texts -def get_texts(best_paths: k2.Fsa) -> List[List[int]]: +def get_texts( + best_paths: k2.Fsa, return_ragged: bool = False +) -> Union[List[List[int]], k2.RaggedTensor]: """Extract the texts (as word IDs) from the best-path FSAs. Args: best_paths: @@ -194,6 +196,9 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: containing multiple FSAs, which is expected to be the result of k2.shortest_path (otherwise the returned values won't be meaningful). + return_ragged: + True to return a ragged tensor with two axes [utt][word_id]. + False to return a list-of-list word IDs. Returns: Returns a list of lists of int, containing the label sequences we decoded. @@ -216,7 +221,10 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]: aux_labels = aux_labels.remove_values_leq(0) assert aux_labels.num_axes == 2 - return aux_labels.tolist() + if return_ragged: + return aux_labels + else: + return aux_labels.tolist() def store_transcripts( diff --git a/test/test_decode.py b/test/test_decode.py new file mode 100644 index 000000000..7ef127781 --- /dev/null +++ b/test/test_decode.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +You can run this file in one of the two ways: + + (1) cd icefall; pytest test/test_decode.py + (2) cd icefall; ./test/test_decode.py +""" + +import k2 +from icefall.decode import Nbest + + +def test_nbest_from_lattice(): + s = """ + 0 1 1 10 0.1 + 0 1 5 10 0.11 + 0 1 2 20 0.2 + 1 2 3 30 0.3 + 1 2 4 40 0.4 + 2 3 -1 -1 0.5 + 3 + """ + lattice = k2.Fsa.from_str(s, acceptor=False) + lattice = k2.Fsa.from_fsas([lattice, lattice]) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=10, + use_double_scores=True, + lattice_score_scale=0.5, + ) + # each lattice has only 4 distinct paths that have different word sequences: + # 10->30 + # 10->40 + # 20->30 + # 20->40 + # + # So there should be only 4 paths for each lattice in the Nbest object + assert nbest.fsa.shape[0] == 4 * 2 + assert nbest.shape.row_splits(1).tolist() == [0, 4, 8] + + nbest2 = nbest.intersect(lattice) + tot_scores = nbest2.tot_scores() + argmax = tot_scores.argmax() + best_path = k2.index_fsa(nbest2.fsa, argmax) + print(best_path[0]) From 455693aedebf4094fe1e82f26d4264a885b399fe Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 22 Sep 2021 16:37:20 +0800 Subject: [PATCH 12/12] Fix `hasattr` of AttributeDict. (#52) --- icefall/utils.py | 18 +++++++++++++----- test/test_utils.py | 11 +++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/icefall/utils.py b/icefall/utils.py index 2324201c3..23b4dd6c7 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -146,12 +146,20 @@ def get_env_info(): } -# See -# https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute # noqa class AttributeDict(dict): - __slots__ = () - __getattr__ = dict.__getitem__ - __setattr__ = dict.__setitem__ + def __getattr__(self, key): + if key in self: + return self[key] + raise AttributeError(f"No such attribute '{key}'") + + def __setattr__(self, key, value): + self[key] = value + + def __delattr__(self, key): + if key in self: + del self[key] + return + raise AttributeError(f"No such attribute '{key}'") def encode_supervisions( diff --git a/test/test_utils.py b/test/test_utils.py index b4c9358fd..7ac52b289 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -108,3 +108,14 @@ def test_attribute_dict(): assert s["b"] == 20 s.c = 100 assert s["c"] == 100 + assert hasattr(s, "a") + assert hasattr(s, "b") + assert getattr(s, "a") == 10 + del s.a + assert hasattr(s, "a") is False + setattr(s, "c", 100) + s.c = 100 + try: + del s.a + except AttributeError as ex: + print(f"Caught exception: {ex}")