diff --git a/.github/workflows/run-pretrained.yml b/.github/workflows/run-pretrained.yml new file mode 100644 index 000000000..97d3c32d2 --- /dev/null +++ b/.github/workflows/run-pretrained.yml @@ -0,0 +1,106 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-pre-trained-conformer-ctc + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +jobs: + run_pre_trained_conformer_ctc: + if: github.event.label.name == 'ready' || github.event_name == 'push' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.6, 3.7, 3.8, 3.9] + torch: ["1.8.1"] + k2-version: ["1.9.dev20210919"] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Python dependencies + run: | + python3 -m pip install --upgrade pip pytest + pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ + + python3 -m pip install git+https://github.com/lhotse-speech/lhotse + python3 -m pip install kaldifeat + # We are in ./icefall and there is a file: requirements.txt in it + pip install -r requirements.txt + + - name: Install graphviz + shell: bash + run: | + python3 -m pip install -qq graphviz + sudo apt-get -qq install graphviz + + - name: Download pre-trained model + shell: bash + run: | + sudo apt-get -qq install git-lfs tree sox + cd egs/librispeech/ASR + mkdir tmp + cd tmp + git lfs install + git clone https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 + cd .. 
+          tree tmp
+          soxi tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
+          ls -lh tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
+
+      - name: Run CTC decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/librispeech/ASR
+          ./conformer_ctc/pretrained.py \
+            --num-classes 500 \
+            --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/bpe.model \
+            --method ctc-decoding \
+            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
+            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
+            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac
+
+      - name: Run HLG decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/librispeech/ASR
+          ./conformer_ctc/pretrained.py \
+            --num-classes 500 \
+            --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
+            --words-file ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/words.txt \
+            --HLG ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/HLG.pt \
+            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
+            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
+            ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 150b5258a..6b72b5a0c 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -46,10 +46,18 @@ jobs:
       with:
         python-version: ${{ matrix.python-version }}
 
+      - name: Install libsndfile and libsox
+        if: startsWith(matrix.os, 'ubuntu')
+        run: |
+          sudo apt update
+          sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
+          sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all
+
       - name: Install Python dependencies
         run: |
           python3 -m pip install --upgrade pip pytest
           pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+          pip install git+https://github.com/lhotse-speech/lhotse
 
           # icefall requirements
           pip install -r requirements.txt
@@ -84,3 +92,7 @@ jobs:
             echo "lib_path: $lib_path"
             export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
             pytest ./test
+
+            # run tests for conformer ctc
+            cd egs/librispeech/ASR/conformer_ctc
+            pytest
diff --git a/.gitignore b/.gitignore
index e6c84ca5e..f4f703243 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+icefall.egg-info/
 data
 __pycache__
 path.sh
diff --git a/README.md b/README.md
index dc03c5883..298feca2e 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,22 @@ The WER for this model is:
 
 We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
 
+
+## Deployment with C++
+
+Once you have trained a model in icefall, you may want to deploy it with C++,
+without Python dependencies.
+
+Please refer to the documentation
+
+for how to do this.
+
+We also provide a Colab notebook showing you how to run a torch scripted model in [k2][k2] with C++.
+Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing)
+
+
 [LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
 [LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
 [yesno]: egs/yesno/ASR
 [librispeech]: egs/librispeech/ASR
+[k2]: https://github.com/k2-fsa/k2
diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst
index 40100bc5a..57ac246e1 100644
--- a/docs/source/recipes/librispeech/conformer_ctc.rst
+++ b/docs/source/recipes/librispeech/conformer_ctc.rst
@@ -1,4 +1,4 @@
-Confromer CTC
+Conformer CTC
 =============
 
 This tutorial shows you how to run a conformer ctc model
@@ -20,6 +20,7 @@ In this tutorial, you will learn:
   - (2) How to start the training, either with a single GPU or multiple GPUs
   - (3) How to do decoding after training, with n-gram LM rescoring and attention decoder rescoring
   - (4) How to use a pre-trained model, provided by us
+  - (5) How to deploy your trained model in C++, without Python dependencies
 
 Data preparation
 ----------------
@@ -292,16 +293,25 @@ The commonly used options are:
   - ``--method``
 
-    This specifies the decoding method.
+    This specifies the decoding method. This script supports 7 decoding methods.
+    For ctc-decoding, it uses a sentence piece model to convert word pieces to
+    words, and it needs neither a lexicon nor an n-gram LM.
 
-    The following command uses attention decoder for rescoring:
+    For example, the following command uses CTC topology for decoding:
 
     .. code-block::
 
       $ cd egs/librispeech/ASR
-      $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
+      $ ./conformer_ctc/decode.py --method ctc-decoding --max-duration 300
 
-  - ``--lattice-score-scale``
+    And the following command uses attention decoder for rescoring:
+
+    .. code-block::
+
+      $ cd egs/librispeech/ASR
+      $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5
+
+  - ``--nbest-scale``
 
     It is used to scale down lattice scores so that there are more unique
     paths for rescoring.
 
@@ -311,6 +321,61 @@ The commonly used options are:
 
     It has the same meaning as the one during training. A larger
     value may cause OOM.
 
+Here are some results for CTC decoding with a vocab size of 500:
+
+Usage:
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./conformer_ctc/decode.py \
+    --epoch 25 \
+    --avg 1 \
+    --max-duration 300 \
+    --exp-dir conformer_ctc/exp \
+    --lang-dir data/lang_bpe_500 \
+    --method ctc-decoding
+
+The output is given below:
+
+..
code-block:: bash + + 2021-09-26 12:44:31,033 INFO [decode.py:537] Decoding started + 2021-09-26 12:44:31,033 INFO [decode.py:538] + {'lm_dir': PosixPath('data/lm'), 'subsampling_factor': 4, 'vgg_frontend': False, 'use_feat_batchnorm': True, + 'feature_dim': 80, 'nhead': 8, 'attention_dim': 512, 'num_decoder_layers': 6, 'search_beam': 20, 'output_beam': 8, + 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, + 'epoch': 25, 'avg': 1, 'method': 'ctc-decoding', 'num_paths': 100, 'nbest_scale': 0.5, + 'export': False, 'exp_dir': PosixPath('conformer_ctc/exp'), 'lang_dir': PosixPath('data/lang_bpe_500'), 'full_libri': False, + 'feature_dir': PosixPath('data/fbank'), 'max_duration': 100, 'bucketing_sampler': False, 'num_buckets': 30, + 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, + 'shuffle': True, 'return_cuts': True, 'num_workers': 2} + 2021-09-26 12:44:31,406 INFO [lexicon.py:113] Loading pre-compiled data/lang_bpe_500/Linv.pt + 2021-09-26 12:44:31,464 INFO [decode.py:548] device: cuda:0 + 2021-09-26 12:44:36,171 INFO [checkpoint.py:92] Loading checkpoint from conformer_ctc/exp/epoch-25.pt + 2021-09-26 12:44:36,776 INFO [decode.py:652] Number of model parameters: 109226120 + 2021-09-26 12:44:37,714 INFO [decode.py:473] batch 0/206, cuts processed until now is 12 + 2021-09-26 12:45:15,944 INFO [decode.py:473] batch 100/206, cuts processed until now is 1328 + 2021-09-26 12:45:54,443 INFO [decode.py:473] batch 200/206, cuts processed until now is 2563 + 2021-09-26 12:45:56,411 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-clean-ctc-decoding.txt + 2021-09-26 12:45:56,592 INFO [utils.py:331] [test-clean-ctc-decoding] %WER 3.26% [1715 / 52576, 163 ins, 128 del, 1424 sub ] + 2021-09-26 12:45:56,807 INFO [decode.py:506] Wrote detailed error stats to conformer_ctc/exp/errs-test-clean-ctc-decoding.txt + 2021-09-26 12:45:56,808 INFO [decode.py:522] + For test-clean, WER of different settings are: + ctc-decoding 3.26 best for test-clean + + 2021-09-26 12:45:57,362 INFO [decode.py:473] batch 0/203, cuts processed until now is 15 + 2021-09-26 12:46:35,565 INFO [decode.py:473] batch 100/203, cuts processed until now is 1477 + 2021-09-26 12:47:15,106 INFO [decode.py:473] batch 200/203, cuts processed until now is 2922 + 2021-09-26 12:47:16,131 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-other-ctc-decoding.txt + 2021-09-26 12:47:16,208 INFO [utils.py:331] [test-other-ctc-decoding] %WER 8.21% [4295 / 52343, 396 ins, 315 del, 3584 sub ] + 2021-09-26 12:47:16,432 INFO [decode.py:506] Wrote detailed error stats to conformer_ctc/exp/errs-test-other-ctc-decoding.txt + 2021-09-26 12:47:16,432 INFO [decode.py:522] + For test-other, WER of different settings are: + ctc-decoding 8.21 best for test-other + + 2021-09-26 12:47:16,433 INFO [decode.py:680] Done! + Pre-trained Model ----------------- @@ -381,7 +446,6 @@ After downloading, you will have the following files: 6 directories, 11 files **File descriptions**: - - ``data/lang_bpe/HLG.pt`` It is the decoding graph. @@ -462,12 +526,58 @@ Usage displays the help information. -It supports three decoding methods: +It supports 4 decoding methods: + - CTC decoding - HLG decoding - HLG + n-gram LM rescoring - HLG + n-gram LM rescoring + attention decoder rescoring +CTC decoding +^^^^^^^^^^^^ + +CTC decoding uses the best path of the decoding lattice as the decoding result +without any LM or lexicon. 
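+
+Under the hood, ``pretrained.py`` builds a CTC topology ``H`` with
+``k2.ctc_topo``, intersects it with the network output to get a lattice,
+takes the best path, and maps the resulting token IDs back to words with the
+BPE model. Below is a minimal sketch of that pipeline, reusing the helpers
+this recipe already imports; ``nnet_output`` and ``supervision_segments``
+are assumed to come from the model, as in ``pretrained.py``:
+
+.. code-block:: python
+
+   import k2
+   import sentencepiece as spm
+
+   from icefall.decode import get_lattice, one_best_decoding
+   from icefall.utils import get_texts
+
+   bpe_model = spm.SentencePieceProcessor()
+   bpe_model.load("data/lang_bpe/bpe.model")
+
+   # The CTC topology maps framewise token sequences (with repeats
+   # and blanks) to linear token sequences
+   H = k2.ctc_topo(max_token=bpe_model.vocab_size() - 1, modified=False)
+
+   # nnet_output: (N, T, C) log-probs from the encoder
+   # supervision_segments: (num_utts, 3) int32 tensor
+   lattice = get_lattice(
+       nnet_output=nnet_output,
+       decoding_graph=H,
+       supervision_segments=supervision_segments,
+       search_beam=20,
+       output_beam=8,
+       min_active_states=30,
+       max_active_states=10000,
+       subsampling_factor=4,
+   )
+
+   best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
+   token_ids = get_texts(best_path)  # a list of list of token IDs
+   hyps = [s.split() for s in bpe_model.decode(token_ids)]  # words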
+ +The command to run CTC decoding is: + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./conformer_ctc/pretrained.py \ + --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ + --bpe-model ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/bpe.model \ + --method ctc-decoding \ + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac + +The output is given below: + +.. code-block:: + + 2021-10-13 11:21:50,896 INFO [pretrained.py:236] device: cuda:0 + 2021-10-13 11:21:50,896 INFO [pretrained.py:238] Creating model + 2021-10-13 11:21:56,669 INFO [pretrained.py:255] Constructing Fbank computer + 2021-10-13 11:21:56,670 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] + 2021-10-13 11:21:56,683 INFO [pretrained.py:271] Decoding started + 2021-10-13 11:21:57,341 INFO [pretrained.py:290] Building CTC topology + 2021-10-13 11:21:57,625 INFO [lexicon.py:113] Loading pre-compiled tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/Linv.pt + 2021-10-13 11:21:57,679 INFO [pretrained.py:299] Loading BPE model + 2021-10-13 11:22:00,076 INFO [pretrained.py:314] Use CTC decoding + 2021-10-13 11:22:00,087 INFO [pretrained.py:400] + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: + AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac: + GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED + BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: + YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + 2021-10-13 11:22:00,087 INFO [pretrained.py:402] Decoding Done + HLG decoding ^^^^^^^^^^^^ @@ -490,14 +600,14 @@ The output is given below: .. 
code-block:: - 2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0 - 2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model - 2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt - 2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer - 2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] - 2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started - 2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding - 2021-08-20 11:03:19,149 INFO [pretrained.py:339] + 2021-10-13 11:25:19,458 INFO [pretrained.py:236] device: cuda:0 + 2021-10-13 11:25:19,458 INFO [pretrained.py:238] Creating model + 2021-10-13 11:25:25,342 INFO [pretrained.py:255] Constructing Fbank computer + 2021-10-13 11:25:25,343 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] + 2021-10-13 11:25:25,356 INFO [pretrained.py:271] Decoding started + 2021-10-13 11:25:26,026 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt + 2021-10-13 11:25:33,735 INFO [pretrained.py:359] Use HLG decoding + 2021-10-13 11:25:34,013 INFO [pretrained.py:400] ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS @@ -508,7 +618,7 @@ The output is given below: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - 2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done + 2021-10-13 11:25:34,014 INFO [pretrained.py:402] Decoding Done HLG decoding + LM rescoring ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -536,15 +646,15 @@ Its output is: .. 
code-block:: - 2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0 - 2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model - 2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt - 2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt - 2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer - 2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] - 2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started - 2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring - 2021-08-20 11:13:11,736 INFO [pretrained.py:339] + 2021-10-13 11:28:19,129 INFO [pretrained.py:236] device: cuda:0 + 2021-10-13 11:28:19,129 INFO [pretrained.py:238] Creating model + 2021-10-13 11:28:23,531 INFO [pretrained.py:255] Constructing Fbank computer + 2021-10-13 11:28:23,532 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] + 2021-10-13 11:28:23,544 INFO [pretrained.py:271] Decoding started + 2021-10-13 11:28:24,141 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt + 2021-10-13 11:28:30,752 INFO [pretrained.py:338] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt + 2021-10-13 11:28:48,308 INFO [pretrained.py:364] Use HLG decoding + LM rescoring + 2021-10-13 11:28:48,815 INFO [pretrained.py:400] ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS @@ -555,7 +665,7 @@ Its output is: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - 2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done + 2021-10-13 11:28:48,815 INFO [pretrained.py:402] Decoding Done HLG decoding + LM rescoring + attention decoder rescoring ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -577,7 +687,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is: --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 1.3 \ --attention-decoder-scale 1.2 \ - --lattice-score-scale 0.5 \ + --nbest-scale 0.5 \ --num-paths 100 \ --sos-id 1 \ --eos-id 1 \ @@ -589,15 +699,15 @@ The output is below: .. 
code-block:: - 2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0 - 2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model - 2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt - 2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt - 2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer - 2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] - 2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started - 2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring - 2021-08-20 11:20:05,805 INFO [pretrained.py:339] + 2021-10-13 11:29:50,106 INFO [pretrained.py:236] device: cuda:0 + 2021-10-13 11:29:50,106 INFO [pretrained.py:238] Creating model + 2021-10-13 11:29:56,063 INFO [pretrained.py:255] Constructing Fbank computer + 2021-10-13 11:29:56,063 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] + 2021-10-13 11:29:56,077 INFO [pretrained.py:271] Decoding started + 2021-10-13 11:29:56,770 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt + 2021-10-13 11:30:04,023 INFO [pretrained.py:338] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt + 2021-10-13 11:30:18,163 INFO [pretrained.py:372] Use HLG + LM rescoring + attention decoder rescoring + 2021-10-13 11:30:19,367 INFO [pretrained.py:400] ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS @@ -608,7 +718,7 @@ The output is below: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - 2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done + 2021-10-13 11:30:19,367 INFO [pretrained.py:402] Decoding Done Colab notebook -------------- @@ -629,3 +739,119 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained **Congratulations!** You have finished the librispeech ASR recipe with conformer CTC models in ``icefall``. + +If you want to deploy your trained model in C++, please read the following section. + +Deployment with C++ +------------------- + +This section describes how to deploy your trained model in C++, without +Python dependencies. + +We assume you have run ``./prepare.sh`` and have the following directories available: + +.. code-block:: bash + + data + |-- lang_bpe + +Also, we assume your checkpoints are saved in ``conformer_ctc/exp``. + +If you know that averaging 20 checkpoints starting from ``epoch-30.pt`` yields the +lowest WER, you can run the following commands + +.. 
code-block:: + + $ cd egs/librispeech/ASR + $ ./conformer_ctc/export.py \ + --epoch 30 \ + --avg 20 \ + --jit 1 \ + --lang-dir data/lang_bpe \ + --exp-dir conformer_ctc/exp + +to get a torch scripted model saved in ``conformer_ctc/exp/cpu_jit.pt``. + +Now you have all needed files ready. Let us compile k2 from source: + +.. code-block:: bash + + $ cd $HOME + $ git clone https://github.com/k2-fsa/k2 + $ cd k2 + $ git checkout v2.0-pre + +.. CAUTION:: + + You have to switch to the branch ``v2.0-pre``! + +.. code-block:: bash + + $ mkdir build-release + $ cd build-release + $ cmake -DCMAKE_BUILD_TYPE=Release .. + $ make -j decode + # You will find an executable: `./bin/decode` + +Now you are ready to go! + +To view the usage of ``./bin/decode``, run: + +.. code-block:: + + $ ./bin/decode + +It will show you the following message: + +.. code-block:: + + Please provide --jit_pt + + (1) CTC decoding + ./bin/decode \ + --use_ctc_decoding true \ + --jit_pt \ + --bpe_model \ + /path/to/foo.wav \ + /path/to/bar.wav \ + + (2) HLG decoding + ./bin/decode \ + --use_ctc_decoding false \ + --jit_pt \ + --hlg \ + --word-table \ + /path/to/foo.wav \ + /path/to/bar.wav \ + + + --use_gpu false to use CPU + --use_gpu true to use GPU + +``./bin/decode`` supports two types of decoding at present: CTC decoding and HLG decoding. + +CTC decoding +^^^^^^^^^^^^ + +You need to provide: + + - ``--jit_pt``, this is the file generated by ``conformer_ctc/export.py``. You can find it + in ``conformer_ctc/exp/cpu_jit.pt``. + - ``--bpe_model``, this is a sentence piece model generated by ``prepare.sh``. You can find + it in ``data/lang_bpe/bpe.model``. + + +HLG decoding +^^^^^^^^^^^^ + +You need to provide: + + - ``--jit_pt``, this is the same file as in CTC decoding. + - ``--hlg``, this file is generated by ``prepare.sh``. You can find it in ``data/lang_bpe/HLG.pt``. + - ``--word-table``, this file is generated by ``prepare.sh``. You can find it in ``data/lang_bpe/words.txt``. + +We do provide a Colab notebook, showing you how to run a torch scripted model in C++. +Please see |librispeech asr conformer ctc torch script colab notebook| + +.. |librispeech asr conformer ctc torch script colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg + :target: https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index d04e912bf..eb679b951 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -38,14 +38,16 @@ python conformer_ctc/train.py --bucketing-sampler True \ --concatenate-cuts False \ --max-duration 200 \ --full-libri True \ - --world-size 4 + --world-size 4 \ + --lang-dir data/lang_bpe_5000 -python conformer_ctc/decode.py --lattice-score-scale 0.5 \ +python conformer_ctc/decode.py --nbest-scale 0.5 \ --epoch 34 \ --avg 20 \ --method attention-decoder \ --max-duration 20 \ - --num-paths 100 + --num-paths 100 \ + --lang-dir data/lang_bpe_5000 ``` ### LibriSpeech training results (Tdnn-Lstm) diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md index 23b51167b..164c3e53e 100644 --- a/egs/librispeech/ASR/conformer_ctc/README.md +++ b/egs/librispeech/ASR/conformer_ctc/README.md @@ -1,3 +1,53 @@ +## Introduction + Please visit for how to run this recipe. + +## How to compute framewise alignment information + +### Step 1: Train a model + +Please use `conformer_ctc/train.py` to train a model. +See +for how to do it. 
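+
+For reference, the training command used for the results in
+`egs/librispeech/ASR/RESULTS.md` looks like this (adjust `--world-size`
+and `--max-duration` to your hardware):
+
+```
+python conformer_ctc/train.py \
+  --bucketing-sampler True \
+  --concatenate-cuts False \
+  --max-duration 200 \
+  --full-libri True \
+  --world-size 4 \
+  --lang-dir data/lang_bpe_5000
+```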
+
+### Step 2: Compute framewise alignment
+
+Run
+
+```
+# Choose a checkpoint and determine the number of checkpoints to average
+epoch=30
+avg=15
+./conformer_ctc/ali.py \
+  --epoch $epoch \
+  --avg $avg \
+  --max-duration 500 \
+  --bucketing-sampler 0 \
+  --full-libri 1 \
+  --exp-dir conformer_ctc/exp \
+  --lang-dir data/lang_bpe_5000 \
+  --ali-dir data/ali_5000
+```
+and you will get four files inside the folder `data/ali_5000`:
+
+```
+$ ls -lh data/ali_5000
+total 546M
+-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt
+-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt
+-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt
+-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt
+```
+
+**Note**: It can take more than 3 hours to compute the alignment
+for the training dataset, which contains 960 * 3 = 2880 hours of data.
+
+**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those
+in `conformer_ctc/train.py`.
+
+**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`.
+Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`.
+
+**TODO:** Add doc about how to use the extracted alignment in another pull request.
diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py
new file mode 100755
index 000000000..3d817a8f6
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/ali.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import one_best_decoding
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    get_alignments,
+    get_env_info,
+    save_alignments,
+    setup_logger,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=34,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=20,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. 
", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_5000", + help="The lang dir", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--ali-dir", + type=str, + default="data/ali_500", + help="The experiment dir", + ) + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "subsampling_factor": 4, + "num_decoder_layers": 6, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "output_beam": 10, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def compute_alignments( + model: torch.nn.Module, + dl: torch.utils.data.DataLoader, + params: AttributeDict, + graph_compiler: BpeCtcTrainingGraphCompiler, +) -> List[Tuple[str, List[int]]]: + """Compute the framewise alignments of a dataset. + + Args: + model: + The neural network model. + dl: + Dataloader containing the dataset. + params: + Parameters for computing alignments. + graph_compiler: + It converts token IDs to decoding graphs. + Returns: + Return a list of tuples. Each tuple contains two entries: + - Utterance ID + - Framewise alignments (token IDs) after subsampling + """ + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + num_cuts = 0 + + device = graph_compiler.device + ans = [] + for batch_idx, batch in enumerate(dl): + feature = batch["inputs"] + + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + + cut_ids = [] + for cut in supervisions["cut"]: + assert len(cut.supervisions) == 1 + cut_ids.append(cut.id) + + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is [N, T, C] + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + # we need also to sort cut_ids as encode_supervisions() + # reorders "texts". 
+        # In general, new2old is an identity map since lhotse sorts the returned
+        # cuts by duration in descending order
+        new2old = supervision_segments[:, 0].tolist()
+        cut_ids = [cut_ids[i] for i in new2old]
+
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        lattice = k2.intersect_dense(
+            decoding_graph,
+            dense_fsa_vec,
+            params.output_beam,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice,
+            use_double_scores=params.use_double_scores,
+        )
+
+        ali_ids = get_alignments(best_path)
+        assert len(ali_ids) == len(cut_ids)
+        ans += list(zip(cut_ids, ali_ids))
+
+        num_cuts += len(ali_ids)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
+
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    assert args.return_cuts is True
+    assert args.concatenate_cuts is False
+    if args.full_libri is False:
+        print("Changing --full-libri to True")
+        args.full_libri = True
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log/ali")
+
+    logging.info("Computing alignment - started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if start >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.load_state_dict(average_checkpoints(filenames))
+
+    model.to(device)
+    model.eval()
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_dl = librispeech.train_dataloaders()
+    valid_dl = librispeech.valid_dataloaders()
+    test_dl = librispeech.test_dataloaders()  # a list
+
+    ali_dir = Path(params.ali_dir)
+    ali_dir.mkdir(exist_ok=True)
+
+    enabled_datasets = {
+        "test_clean": test_dl[0],
+        "test_other": test_dl[1],
+        "train-960": train_dl,
+        "valid": valid_dl,
+    }
+    # For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
+    # compute the alignments if you use --max-duration=500
+    #
+    # There are 960 * 3 = 2880 hours of data and it takes only
+    # 3 hours 40 minutes to get the alignment.
+ # The RTF is roughly: 3.67 / 2880 = 0.0012743 + # + # At the end, you would see + # 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa + # 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa + for name, dl in enabled_datasets.items(): + logging.info(f"Processing {name}") + if name == "train-960": + logging.info( + f"It will take about 3 hours 40 minutes for {name}, " + "which contains 960 * 3 = 2880 hours of data" + ) + alignments = compute_alignments( + model=model, + dl=dl, + params=params, + graph_compiler=graph_compiler, + ) + num_utt = len(alignments) + alignments = dict(alignments) + assert num_utt == len(alignments) + filename = ali_dir / f"{name}.pt" + save_alignments( + alignments=alignments, + subsampling_factor=params.subsampling_factor, + filename=filename, + ) + logging.info( + f"For dataset {name}, its alignments are saved to {filename}" + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index b5b41c82e..bddb832b0 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 +import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule @@ -42,6 +43,7 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -77,6 +79,9 @@ def get_parser(): default="attention-decoder", help="""Decoding method. Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. - (1) 1best. Extract the best path from the decoding lattice as the decoding result. - (2) nbest. Extract n paths from the decoding lattice; the path @@ -106,7 +111,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help="""The scale to be applied to `lattice.scores`. @@ -128,14 +133,26 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_5000", + help="The lang dir", + ) + return parser def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), # parameters for conformer "subsampling_factor": 4, @@ -151,6 +168,7 @@ def get_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "env_info": get_env_info(), } ) return params @@ -159,13 +177,15 @@ def get_params() -> AttributeDict: def decode_one_batch( params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], batch: dict, word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[int]]]: +) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. 
The dict has the following format:
@@ -190,7 +210,11 @@
       model:
         The neural model.
       HLG:
-        The decoding graph.
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -209,7 +233,10 @@
       Return the decoding result. See above description for the format of
       the returned dict.
     """
-    device = HLG.device
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -229,9 +256,17 @@
         1,
     ).to(torch.int32)
 
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=decoding_graph,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
@@ -240,6 +275,24 @@
         subsampling_factor=params.subsampling_factor,
     )
 
+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-lists of token IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
    if params.method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons @@ -250,12 +303,12 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=supervisions["text"], word_table=word_table, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, oov="", ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa return {key: hyps} if params.method in ["1best", "nbest"]: @@ -269,9 +322,9 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) - key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] @@ -293,7 +346,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( @@ -319,7 +372,7 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -340,12 +393,14 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: @@ -356,7 +411,11 @@ def decode_dataset( model: The neural model. HLG: - The decoding graph. + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. word_table: It is the word symbol table. 
sos_id: @@ -391,6 +450,8 @@ def decode_dataset( params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, batch=batch, word_table=word_table, G=G, @@ -469,6 +530,8 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) @@ -496,14 +559,26 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) - HLG = HLG.to(device) - assert HLG.requires_grad is False + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) + HLG = HLG.to(device) + assert HLG.requires_grad is False - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() if params.method in ( "nbest-rescoring", @@ -593,6 +668,8 @@ def main(): params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, word_table=lexicon.word_table, G=G, sos_id=sos_id, diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py new file mode 100755 index 000000000..79e026dac --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_5000", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + } + ) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index c924b87bb..beed6f73b 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -23,6 +24,7 @@ from typing import List import k2 import kaldifeat +import sentencepiece as spm import torch import torchaudio from conformer import Conformer @@ -34,7 +36,7 @@ from icefall.decode import ( rescore_with_attention_decoder, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -54,12 +56,25 @@ def get_parser(): parser.add_argument( "--words-file", type=str, - required=True, - help="Path to words.txt", + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, ) parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, ) parser.add_argument( @@ -68,6 +83,10 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -125,7 +144,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help=""" @@ -139,7 +158,7 @@ def get_parser(): parser.add_argument( "--sos-id", - type=float, + type=int, default=1, help=""" Used only when method is attention-decoder. @@ -147,9 +166,18 @@ def get_parser(): """, ) + parser.add_argument( + "--num-classes", + type=int, + default=5000, + help=""" + Vocab size in the BPE model. + """, + ) + parser.add_argument( "--eos-id", - type=float, + type=int, default=1, help=""" Used only when method is attention-decoder. 
@@ -180,7 +208,6 @@ def get_params() -> AttributeDict: "use_feat_batchnorm": True, "feature_dim": 80, "nhead": 8, - "num_classes": 5000, "attention_dim": 512, "num_decoder_layers": 6, # parameters for decoding @@ -223,7 +250,13 @@ def main(): args = parser.parse_args() params = get_params() + if args.method != "attention-decoder": + # to save memory as the attention decoder + # will not be used + params.num_decoder_layers = 0 + params.update(vars(args)) + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") @@ -245,27 +278,10 @@ def main(): ) checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) + model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = device @@ -299,52 +315,108 @@ def main(): dtype=torch.int32, ) - lattice = get_lattice( - nnet_output=nnet_output, - HLG=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) - if params.method == "1best": - logging.info("Use HLG decoding") best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "attention-decoder": - logging.info("Use HLG + LM rescoring + attention decoder rescoring") - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None - ) - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=params.sos_id, - eos_id=params.eos_id, - lattice_score_scale=params.lattice_score_scale, - ngram_lm_scale=params.ngram_lm_scale, - 
attention_scale=params.attention_decoder_scale, - ) - best_path = next(iter(best_path_dict.values())) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + nbest_scale=params.nbest_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 8c1fc9595..6495cb166 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang,
-#                                                            Wei Kang)
+#                                                            Wei Kang
+#                                                            Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,16 +22,16 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple
 
 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -43,7 +44,9 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
+    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -92,6 +95,26 @@
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_5000",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
     return parser
 
 
@@ -106,12 +129,6 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-    - exp_dir: It specifies the directory where all training related
-               files, e.g., checkpoints, log, etc, are saved
-
-    - lang_dir: It contains language related input files such as
-                "lexicon.txt"
-
     - best_train_loss: Best training loss so far. It is used to
           select the model that has the lowest training loss. It is
           updated during the training.
@@ -162,14 +179,12 @@
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang_bpe"),
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 10,
+            "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,
             # parameters for conformer
@@ -188,6 +203,7 @@
             "weight_decay": 1e-6,
             "lr_factor": 5.0,
             "warm_step": 80000,
+            "env_info": get_env_info(),
         }
     )
 
@@ -287,7 +303,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -367,15 +383,17 @@
         loss = ctc_loss
         att_loss = torch.tensor([0])
 
-    # train_frames and valid_frames are used for printing.
- if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() - assert loss.requires_grad == is_training - return loss, ctc_loss.detach(), att_loss.detach() + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + return loss, info def compute_validation_loss( @@ -384,18 +409,14 @@ def compute_validation_loss( graph_compiler: BpeCtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, -) -> None: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ +) -> MetricsTracker: + """Run the validation process.""" model.eval() - tot_loss = 0.0 - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -403,36 +424,17 @@ def compute_validation_loss( is_training=False, ) assert loss.requires_grad is False - assert ctc_loss.requires_grad is False - assert att_loss.requires_grad is False - - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - - tot_ctc_loss += ctc_loss.detach().cpu().item() - tot_att_loss += att_loss.detach().cpu().item() - - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor( - [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames], - device=loss.device, - ) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_ctc_loss = s[1] - tot_att_loss = s[2] - tot_frames = s[3] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames - params.valid_ctc_loss = tot_ctc_loss / tot_frames - params.valid_att_loss = tot_att_loss / tot_frames - - if params.valid_loss < params.best_valid_loss: + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -471,24 +473,21 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 + tot_loss = MetricsTracker() - tot_frames = 0.0 # sum of frames over all batches - params.tot_loss = 0.0 - params.tot_frames = 0.0 for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. 
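A note on the `tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info` line above: MetricsTracker (added to icefall/utils.py later in this patch) supports elementwise addition and scalar multiplication, so this keeps an exponentially decayed running sum with an effective window of roughly `reset_interval` batches, replacing the old hard reset of the counters. A minimal self-contained sketch of the effect; the `Tracker` class below is a hypothetical stand-in, not the icefall class:

# Hypothetical stand-in for icefall.utils.MetricsTracker: a
# defaultdict(int) with elementwise __add__ and scalar __mul__.
from collections import defaultdict


class Tracker(defaultdict):
    def __init__(self):
        super().__init__(int)

    def __add__(self, other: "Tracker") -> "Tracker":
        ans = Tracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "Tracker":
        ans = Tracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans


reset_interval = 200
tot = Tracker()
for step in range(1000):
    cur = Tracker()
    cur["frames"] = 100  # frames in this batch
    cur["loss"] = 50.0   # summed loss over this batch
    # A batch seen k steps ago contributes with weight
    # (1 - 1/reset_interval) ** k, i.e. old batches fade out
    # smoothly instead of being discarded all at once.
    tot = tot * (1 - 1 / reset_interval) + cur

print(tot["loss"] / tot["frames"])  # 0.5, the per-frame loss of recent batches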
@@ -498,75 +497,26 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - ctc_loss_cpu = ctc_loss.detach().cpu().item() - att_loss_cpu = att_loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_ctc_loss += ctc_loss_cpu - tot_att_loss += att_loss_cpu - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - - tot_avg_loss = tot_loss / tot_frames - tot_avg_ctc_loss = tot_ctc_loss / tot_frames - tot_avg_att_loss = tot_att_loss / tot_frames - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, " - f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, " - f"total avg att loss: {tot_avg_att_loss:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % params.log_interval == 0: + if tb_writer is not None: - tb_writer.add_scalar( - "train/current_ctc_loss", - ctc_loss_cpu / params.train_frames, - params.batch_idx_train, + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train ) - tb_writer.add_scalar( - "train/current_att_loss", - att_loss_cpu / params.train_frames, - params.batch_idx_train, + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train ) - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_ctc_loss", - tot_avg_ctc_loss, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_att_loss", - tot_avg_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - - tot_frames = 0.0 # sum of frames over all batches if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + logging.info("Computing validation loss") + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -574,33 +524,14 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, " - f"valid ctc loss {params.valid_ctc_loss:.4f}," - f"valid att loss {params.valid_att_loss:.4f}," - f"valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_ctc_loss", - params.valid_ctc_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/valid_att_loss", - params.valid_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/valid_loss", - params.valid_loss, - params.batch_idx_train, + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train ) - params.train_loss = params.tot_loss / params.tot_frames - + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value if params.train_loss < params.best_train_loss: 
params.best_train_epoch = params.cur_epoch params.best_train_loss = params.train_loss @@ -726,6 +657,8 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) world_size = args.world_size assert world_size >= 1 diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 68a4ff65c..3e6abb695 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -236,6 +236,7 @@ class Transformer(nn.Module): x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) return x + @torch.jit.export def decoder_forward( self, memory: torch.Tensor, @@ -264,11 +265,15 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) @@ -301,6 +306,7 @@ class Transformer(nn.Module): return decoder_loss + @torch.jit.export def decoder_nll( self, memory: torch.Tensor, @@ -331,11 +337,15 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) @@ -649,7 +659,8 @@ class PositionalEncoding(nn.Module): self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout) - self.pe = None + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(0, 0, dtype=torch.float32) def extend_pe(self, x: torch.Tensor) -> None: """Extend the time t in the positional encoding if required. @@ -666,8 +677,7 @@ class PositionalEncoding(nn.Module): """ if self.pe is not None: if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) + self.pe = self.pe.to(dtype=x.dtype, device=x.device) return pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) @@ -972,10 +982,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: Return a new list-of-list, where each sublist starts with SOS ID. """ - ans = [] - for utt in token_ids: - ans.append([sos_id] + utt) - return ans + return [[sos_id] + utt for utt in token_ids] def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: @@ -992,7 +999,4 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: Return a new list-of-list, where each sublist ends with EOS ID. 
""" - ans = [] - for utt in token_ids: - ans.append(utt + [eos_id]) - return ans + return [utt + [eos_id] for utt in token_ids] diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index c1a532fc1..3b2678ec4 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -40,9 +40,9 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - # 5000 - # 2000 - # 1000 + 5000 + 2000 + 1000 500 ) @@ -59,13 +59,13 @@ log() { log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "stage -1: Download LM" + log "Stage -1: Download LM" [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm ./local/download_lm.py --out-dir=$dl_dir/lm fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "stage 0: Download data" + log "Stage 0: Download data" # If you have pre-downloaded it to /path/to/LibriSpeech, # you can create a symlink @@ -130,7 +130,7 @@ fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "State 6: Prepare BPE based lang" + log "Stage 6: Prepare BPE based lang" for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 8290e71d1..950eba438 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule): cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") logging.info("About to create train dataset") - transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + transforms = [ + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ] if self.args.concatenate_cuts: logging.info( f"Using cut concatenation with duration factor " @@ -267,7 +269,7 @@ class LibriSpeechAsrDataModule(DataModule): cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - valid_sampler = SingleCutSampler( + valid_sampler = BucketingSampler( cuts_valid, max_duration=self.args.max_duration, shuffle=False, @@ -300,12 +302,15 @@ class LibriSpeechAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) test_loaders.append(test_dl) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 1e91b1008..d9d019743 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -39,6 +39,7 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -97,7 +98,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help="""The scale to be applied to `lattice.scores`. 
@@ -134,6 +135,7 @@ def get_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "env_info": get_env_info(), } ) return params @@ -146,7 +148,7 @@ def decode_one_batch( batch: dict, lexicon: Lexicon, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[int]]]: +) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -210,7 +212,7 @@ def decode_one_batch( lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -229,7 +231,7 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) key = f"no_rescore-{params.num_paths}" hyps = get_texts(best_path) @@ -248,7 +250,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) else: best_path_dict = rescore_with_whole_lattice( @@ -272,7 +274,7 @@ def decode_dataset( HLG: k2.Fsa, lexicon: Lexicon, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index 0a543d859..e0d6a7a60 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -34,7 +34,7 @@ from icefall.decode import ( one_best_decoding, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -159,6 +159,7 @@ def main(): params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") @@ -232,7 +233,7 @@ def main(): lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 695ee5130..7904b0e61 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang
+#                                                  Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -20,17 +21,17 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple
 
 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
 from asr_datamodule import LibriSpeechAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import TdnnLstm
+from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import StepLR
@@ -43,7 +44,9 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
+    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -168,6 +171,7 @@ def get_params() -> AttributeDict:
             "beam_size": 10,
             "reduction": "sum",
             "use_double_scores": True,
+            "env_info": get_env_info(),
         }
     )
 
@@ -267,7 +271,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -324,13 +328,11 @@
 
     assert loss.requires_grad == is_training
 
-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()
 
-    return loss
+    return loss, info
 
 
 def compute_validation_loss(
@@ -339,16 +341,15 @@
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
-    """Run the validation process. The validation loss
-    is saved in `params.valid_loss`.
+) -> MetricsTracker:
+    """Run the validation process.
""" model.eval() - tot_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -357,22 +359,18 @@ def compute_validation_loss( ) assert loss.requires_grad is False - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] - if params.valid_loss < params.best_valid_loss: + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -411,67 +409,45 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # reset after params.reset_interval of batches - tot_frames = 0.0 # reset after params.reset_interval of batches - - params.tot_loss = 0.0 - params.tot_frames = 0.0 + tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info optimizer.zero_grad() loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % params.log_interval == 0: + if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train ) - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0 - tot_frames = 0 - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -479,13 +455,16 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") + if tb_writer is not None: + 
valid_info.write_summary(
+                    tb_writer,
+                    "train/valid_",
+                    params.batch_idx_train,
+                )
 
-    params.train_loss = params.tot_loss / params.tot_frames
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
 
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh
index 9a0cc48bb..8fcee0290 100755
--- a/egs/yesno/ASR/prepare.sh
+++ b/egs/yesno/ASR/prepare.sh
@@ -24,7 +24,7 @@ log() {
 log "dl_dir: $dl_dir"
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "stage 0: Download data"
+  log "Stage 0: Download data"
   mkdir -p $dl_dir
 
   if [ ! -f $dl_dir/waves_yesno/.completed ]; then
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index e6614e3ce..832fd556e 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -20,19 +20,18 @@ from functools import lru_cache
 from pathlib import Path
 from typing import List
 
+from torch.utils.data import DataLoader
+
+from icefall.dataset.datamodule import DataModule
+from icefall.utils import str2bool
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
     CutConcatenate,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
-    SingleCutSampler,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
-from torch.utils.data import DataLoader
-
-from icefall.dataset.datamodule import DataModule
-from icefall.utils import str2bool
 
 
 class YesNoAsrDataModule(DataModule):
@@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule):
             )
         else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
+            logging.info("Using BucketingSampler.")
+            train_sampler = BucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
@@ -226,12 +225,15 @@
             else PrecomputedFeatures(),
             return_cuts=self.args.return_cuts,
         )
-        sampler = SingleCutSampler(
-            cuts_test, max_duration=self.args.max_duration
+        sampler = BucketingSampler(
+            cuts_test, max_duration=self.args.max_duration, shuffle=False
         )
         logging.debug("About to create test dataloader")
        test_dl = DataLoader(
-            test, batch_size=None, sampler=sampler, num_workers=1
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
         )
         return test_dl
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 325acf316..9df019bf5 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -17,6 +17,7 @@ from icefall.decode import get_lattice, one_best_decoding
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
@@ -124,7 +125,7 @@ def decode_one_batch(
 
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
@@ -256,6 +257,7 @@ def main():
 
     params = get_params()
     params.update(vars(args))
+    params["env_info"] = get_env_info()
 
     setup_logger(f"{params.exp_dir}/log/log-decode")
     logging.info("Decoding started")
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index fb92110e3..75758b984 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -29,7 +29,7 @@ from model import Tdnn
 from torch.nn.utils.rnn import pad_sequence
 
 from icefall.decode import get_lattice, one_best_decoding
-from icefall.utils import AttributeDict, get_texts
+from icefall.utils import AttributeDict, get_env_info, get_texts
 
 
 def get_parser():
@@ -116,6 +117,7 @@ def main():
 
     params = get_params()
     params.update(vars(args))
+    params["env_info"] = get_env_info()
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -175,7 +176,7 @@
 
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index 0f5506d38..e24061fa1 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -4,17 +4,17 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple
 
 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
 from asr_datamodule import YesNoAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import Tdnn
+from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -24,7 +24,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    get_env_info,
+    setup_logger,
+    str2bool,
+)
 
 
 def get_parser():
@@ -122,6 +128,8 @@ def get_params() -> AttributeDict:
 
-    - valid_interval:  Run validation if batch_idx % valid_interval` is 0
+    - valid_interval:  Run validation if batch_idx % valid_interval is 0
 
+    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
     - beam_size: It is used in k2.ctc_loss
 
     - reduction: It is used in k2.ctc_loss
@@ -142,6 +150,7 @@
         "best_valid_epoch": -1,
         "batch_idx_train": 0,
         "log_interval": 10,
+        "reset_interval": 20,
         "valid_interval": 10,
         "beam_size": 10,
         "reduction": "sum",
@@ -245,7 +254,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -305,13 +314,11 @@
 
     assert loss.requires_grad == is_training
 
-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()
 
-    return loss
+    return loss, info
 
 
 def compute_validation_loss(
@@ -320,16 +327,15 @@
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
-    """Run the validation process. The validation loss
-    is saved in `params.valid_loss`.
+) -> MetricsTracker:
+    """Run the validation process.
""" model.eval() - tot_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -338,22 +345,18 @@ def compute_validation_loss( ) assert loss.requires_grad is False - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] - if params.valid_loss < params.best_valid_loss: + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -392,57 +395,45 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_frames = 0.0 # sum of frames over all batches + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info optimizer.zero_grad() loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % params.log_interval == 0: if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -450,19 +441,16 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_loss", - params.valid_loss, + valid_info.write_summary( + tb_writer, + "train/valid_", params.batch_idx_train, ) - params.train_loss = tot_loss / tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] + 
params.train_loss = loss_value if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch @@ -483,6 +471,7 @@ def run(rank, world_size, args): """ params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() fix_random_seed(42) if world_size > 1: diff --git a/icefall/decode.py b/icefall/decode.py index e678e4622..62d27dd68 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -66,7 +66,7 @@ def _intersect_device( def get_lattice( nnet_output: torch.Tensor, - HLG: k2.Fsa, + decoding_graph: k2.Fsa, supervision_segments: torch.Tensor, search_beam: float, output_beam: float, @@ -79,8 +79,9 @@ def get_lattice( Args: nnet_output: It is the output of a neural model of shape `(N, T, C)`. - HLG: - An Fsa, the decoding graph. See also `compile_HLG.py`. + decoding_graph: + An Fsa, the decoding graph. It can be either an HLG + (see `compile_HLG.py`) or an H (see `k2.ctc_topo`). supervision_segments: A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns. Each row contains information for a supervision segment. Column 0 @@ -117,7 +118,7 @@ def get_lattice( ) lattice = k2.intersect_dense_pruned( - HLG, + decoding_graph, dense_fsa_vec, search_beam=search_beam, output_beam=output_beam, @@ -180,7 +181,7 @@ class Nbest(object): lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True, - lattice_score_scale: float = 0.5, + nbest_scale: float = 0.5, ) -> "Nbest": """Construct an Nbest object by **sampling** `num_paths` from a lattice. @@ -206,7 +207,7 @@ class Nbest(object): Return an Nbest instance. """ saved_scores = lattice.scores.clone() - lattice.scores *= lattice_score_scale + lattice.scores *= nbest_scale # path is a ragged tensor with dtype torch.int32. # It has three axes [utt][path][arc_pos] path = k2.random_paths( @@ -446,7 +447,7 @@ def nbest_decoding( lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True, - lattice_score_scale: float = 1.0, + nbest_scale: float = 1.0, ) -> k2.Fsa: """It implements something like CTC prefix beam search using n-best lists. @@ -474,7 +475,7 @@ def nbest_decoding( use_double_scores: True to use double precision floating point in the computation. False to use single precision. - lattice_score_scale: + nbest_scale: It's the scale applied to the `lattice.scores`. A smaller value leads to more unique paths at the risk of missing the correct path. Returns: @@ -484,7 +485,7 @@ def nbest_decoding( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) # nbest.fsa.scores contains 0s @@ -505,7 +506,7 @@ def nbest_oracle( ref_texts: List[str], word_table: k2.SymbolTable, use_double_scores: bool = True, - lattice_score_scale: float = 0.5, + nbest_scale: float = 0.5, oov: str = "", ) -> Dict[str, List[List[int]]]: """Select the best hypothesis given a lattice and a reference transcript. @@ -517,7 +518,7 @@ def nbest_oracle( The decoding result returned from this function is the best result that we can obtain using n-best decoding with all kinds of rescoring techniques. - This function is useful to tune the value of `lattice_score_scale`. + This function is useful to tune the value of `nbest_scale`. Args: lattice: @@ -533,7 +534,7 @@ def nbest_oracle( use_double_scores: True to use double precision for computation. False to use single precision. - lattice_score_scale: + nbest_scale: It's the scale applied to the lattice.scores. A smaller value yields more unique paths. 
oov: @@ -549,7 +550,7 @@ def nbest_oracle( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) hyps = nbest.build_levenshtein_graphs() @@ -590,7 +591,7 @@ def rescore_with_n_best_list( G: k2.Fsa, num_paths: int, lm_scale_list: List[float], - lattice_score_scale: float = 1.0, + nbest_scale: float = 1.0, use_double_scores: bool = True, ) -> Dict[str, k2.Fsa]: """Rescore an n-best list with an n-gram LM. @@ -607,7 +608,7 @@ def rescore_with_n_best_list( Size of nbest list. lm_scale_list: A list of float representing LM score scales. - lattice_score_scale: + nbest_scale: Scale to be applied to ``lattice.score`` when sampling paths using ``k2.random_paths``. use_double_scores: @@ -631,7 +632,7 @@ def rescore_with_n_best_list( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) # nbest.fsa.scores are all 0s at this point @@ -769,7 +770,7 @@ def rescore_with_attention_decoder( memory_key_padding_mask: Optional[torch.Tensor], sos_id: int, eos_id: int, - lattice_score_scale: float = 1.0, + nbest_scale: float = 1.0, ngram_lm_scale: Optional[float] = None, attention_scale: Optional[float] = None, use_double_scores: bool = True, @@ -796,7 +797,7 @@ def rescore_with_attention_decoder( The token ID for SOS. eos_id: The token ID for EOS. - lattice_score_scale: + nbest_scale: It's the scale applied to `lattice.scores`. A smaller value leads to more unique paths at the risk of missing the correct path. ngram_lm_scale: @@ -812,7 +813,7 @@ def rescore_with_attention_decoder( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) # nbest.fsa.scores are all 0s at this point diff --git a/icefall/utils.py b/icefall/utils.py index 1c4dceb0b..287b917c5 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) # # See ../../LICENSE for clarification regarding multiple authors # @@ -16,6 +17,7 @@ import argparse +import collections import logging import os import subprocess @@ -27,10 +29,12 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union import k2 +import k2.version import kaldialign import lhotse import torch import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter Pathlike = Union[str, Path] @@ -233,8 +237,8 @@ def encode_supervisions( supervisions: dict, subsampling_factor: int ) -> Tuple[torch.Tensor, List[str]]: """ - Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor, - and a list of transcription strings. + Encodes Lhotse's ``batch["supervisions"]`` dict into + a pair of torch Tensor, and a list of transcription strings. The supervision tensor has shape ``(batch_size, 3)``. Its second dimension contains information about sequence index [0], @@ -302,6 +306,73 @@ def get_texts( return aux_labels.tolist() +def get_alignments(best_paths: k2.Fsa) -> List[List[int]]: + """Extract the token IDs (from best_paths.labels) from the best-path FSAs. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). 
+ Returns: + Returns a list of lists of int, containing the token sequences we + decoded. For `ans[i]`, its length equals to the number of frames + after subsampling of the i-th utterance in the batch. + """ + # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here + label_shape = best_paths.arcs.shape().remove_axis(1) + # label_shape has axes [fsa][arc] + labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous()) + labels = labels.remove_values_eq(-1) + return labels.tolist() + + +def save_alignments( + alignments: Dict[str, List[int]], + subsampling_factor: int, + filename: str, +) -> None: + """Save alignments to a file. + + Args: + alignments: + A dict containing alignments. Keys of the dict are utterances and + values are the corresponding framewise alignments after subsampling. + subsampling_factor: + The subsampling factor of the model. + filename: + Path to save the alignments. + Returns: + Return None. + """ + ali_dict = { + "subsampling_factor": subsampling_factor, + "alignments": alignments, + } + torch.save(ali_dict, filename) + + +def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: + """Load alignments from a file. + + Args: + filename: + Path to the file containing alignment information. + The file should be saved by :func:`save_alignments`. + Returns: + Return a tuple containing: + - subsampling_factor: The subsampling_factor used to compute + the alignments. + - alignments: A dict containing utterances and their corresponding + framewise alignment, after subsampling. + """ + ali_dict = torch.load(filename) + subsampling_factor = ali_dict["subsampling_factor"] + alignments = ali_dict["alignments"] + return subsampling_factor, alignments + + def store_transcripts( filename: Pathlike, texts: Iterable[Tuple[str, str]] ) -> None: @@ -339,13 +410,13 @@ def write_error_stats( Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 reference words (2337 correct) - - The difference between the reference transcript and predicted results. + - The difference between the reference transcript and predicted result. An instance is given below:: THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES - The above example shows that the reference word is `EDISON`, but it is - predicted to `ADDISON` (a substitution error). + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). Another example is:: @@ -486,3 +557,76 @@ def write_error_stats( print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) return float(tot_err_rate) + + +class MetricsTracker(collections.defaultdict): + def __init__(self): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. + # This class will play a role as metrics tracker. + # It can record many metrics, including but not limited to loss. + super(MetricsTracker, self).__init__(int) + + def __add__(self, other: "MetricsTracker") -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans = "" + for k, v in self.norm_items(): + norm_value = "%.4g" % v + ans += str(k) + "=" + str(norm_value) + ", " + frames = str(self["frames"]) + ans += "over " + frames + " frames." 
+ return ans + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('ctc_loss', 0.1), ('att_loss', 0.07)] + """ + num_frames = self["frames"] if "frames" in self else 1 + ans = [] + for k, v in self.items(): + if k != "frames": + norm_value = float(v) / num_frames + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([float(self[k]) for k in keys], device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary( + self, + tb_writer: SummaryWriter, + prefix: str, + batch_idx: int, + ) -> None: + """Add logging information to a TensorBoard writer. + + Args: + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. + """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..6c720e121 --- /dev/null +++ b/setup.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +from setuptools import find_packages, setup +from pathlib import Path + +icefall_dir = Path(__file__).parent +install_requires = (icefall_dir / "requirements.txt").read_text().splitlines() + +setup( + name="icefall", + version="1.0", + python_requires=">=3.6.0", + description="Speech processing recipes using k2 and Lhotse.", + author="The k2 and Lhotse Development Team", + license="Apache-2.0 License", + packages=find_packages(), + install_requires=install_requires, + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Intended Audience :: Science/Research", + "Operating System :: POSIX :: Linux", + "License :: OSI Approved :: Apache Software License", + "Topic :: Multimedia :: Sound/Audio :: Speech", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", + ], +) diff --git a/test/test_decode.py b/test/test_decode.py index 7ef127781..97964ac67 100644 --- a/test/test_decode.py +++ b/test/test_decode.py @@ -43,7 +43,7 @@ def test_nbest_from_lattice(): lattice=lattice, num_paths=10, use_double_scores=True, - lattice_score_scale=0.5, + nbest_scale=0.5, ) # each lattice has only 4 distinct paths that have different word sequences: # 10->30 diff --git a/test/test_utils.py b/test/test_utils.py index 7ac52b289..b8c742c5a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,7 +20,12 @@ import k2 import pytest import torch -from icefall.utils import AttributeDict, encode_supervisions, get_texts +from icefall.utils import ( + AttributeDict, + encode_supervisions, + get_env_info, + get_texts, +) @pytest.fixture @@ -108,6 +113,7 @@ def test_attribute_dict(): assert s["b"] == 20 s.c = 100 assert s["c"] == 100 + assert hasattr(s, "a") assert hasattr(s, "b") assert getattr(s, "a") == 10 @@ -119,3 +125,8 @@ def test_attribute_dict(): del s.a except AttributeError as ex: print(f"Caught exception: {ex}") + + +def test_get_env_info(): + s = get_env_info() + print(s)
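For reference, the new `ctc-decoding` branch of conformer_ctc/pretrained.py earlier in this patch boils down to the sketch below. It is a simplified reading of the patch, not a drop-in script: `ctc_decode` is a hypothetical helper, the beam values are the recipes' usual defaults, and `nnet_output`/`supervision_segments` are assumed to be computed exactly as in `main()`. Note that `strict=False` in `load_state_dict` pairs with setting `num_decoder_layers = 0`: the decoder weights stored in the checkpoint then have no matching parameters and must be skipped.

import k2
import sentencepiece as spm
import torch

from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import get_texts


def ctc_decode(
    nnet_output: torch.Tensor,           # (N, T, C) log-probs from the model
    supervision_segments: torch.Tensor,  # (N, 3), int32, as built in main()
    bpe_model_path: str,
    num_classes: int,
    subsampling_factor: int = 4,
):
    device = nnet_output.device
    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load(bpe_model_path)

    # H accepts token sequences directly; no lexicon or LM is involved.
    H = k2.ctc_topo(max_token=num_classes - 1, modified=False, device=device)

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=H,  # the parameter this patch renames from `HLG`
        supervision_segments=supervision_segments,
        search_beam=20,    # assumed: the recipes' usual beam settings
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=subsampling_factor,
    )
    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    token_ids = get_texts(best_path)  # aux_labels are BPE token IDs here
    return [s.split() for s in bpe_model.decode(token_ids)]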
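On the transformer.py changes: `torch.jit.script()` only compiles `forward()` and the methods it calls, so `decoder_forward` and `decoder_nll` need `@torch.jit.export` to remain callable on the scripted module during attention rescoring; the `float(...)` padding values and the always-a-tensor `self.pe` likewise exist to satisfy TorchScript's static typing. A toy demonstration, unrelated to the actual Conformer:

import torch


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    @torch.jit.export
    def decoder_nll(self, x: torch.Tensor) -> torch.Tensor:
        # Without @torch.jit.export this method would be dropped
        # from the scripted module, since forward() never calls it.
        return -x


scripted = torch.jit.script(Toy())
scripted.save("toy_scripted.pt")

loaded = torch.jit.load("toy_scripted.pt")
print(loaded(torch.zeros(2)))             # forward(), as usual
print(loaded.decoder_nll(torch.ones(2)))  # kept because of @torch.jit.export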
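The new `save_alignments`/`load_alignments` helpers in icefall/utils.py are a thin wrapper around `torch.save`/`torch.load` on a two-key dict. A usage sketch; the utterance IDs, token IDs, and file name below are made up:

from icefall.utils import load_alignments, save_alignments

# Made-up utterance IDs and framewise token IDs (after subsampling).
alignments = {
    "utt-1": [0, 0, 27, 27, 0, 5],
    "utt-2": [0, 13, 13, 0],
}
save_alignments(alignments, subsampling_factor=4, filename="ali.pt")

factor, ali = load_alignments("ali.pt")
assert factor == 4
assert ali["utt-1"] == [0, 0, 27, 27, 0, 5]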