diff --git a/.flake8 b/.flake8 index 1c0c2cdbb..410cb5482 100644 --- a/.flake8 +++ b/.flake8 @@ -24,6 +24,7 @@ exclude = **/data/**, icefall/shared/make_kn_lm.py, icefall/__init__.py + icefall/ctc/__init__.py ignore = # E203 white space before ":" diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh deleted file mode 100755 index a4959aa01..000000000 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 -git lfs install - -log "Downloading pre-trained model from $repo_url" -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.flac - -log "CTC decoding" - -./conformer_ctc/pretrained.py \ - --method ctc-decoding \ - --num-classes 500 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - $repo/test_wavs/1221-135766-0002.flac - -log "HLG decoding" - -./conformer_ctc/pretrained.py \ - --method 1best \ - --num-classes 500 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - $repo/test_wavs/1221-135766-0002.flac diff --git a/.github/scripts/run-pre-trained-ctc.sh b/.github/scripts/run-pre-trained-ctc.sh new file mode 100755 index 000000000..7d6449c9a --- /dev/null +++ b/.github/scripts/run-pre-trained-ctc.sh @@ -0,0 +1,240 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +pushd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +log "CTC greedy search" + +./zipformer/onnx_pretrained_ctc.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "CTC H decoding" + +./zipformer/onnx_pretrained_ctc_H.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + --H $repo/H.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "CTC HL decoding" + +./zipformer/onnx_pretrained_ctc_HL.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HL $repo/HL.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "CTC HLG decoding" + +./zipformer/onnx_pretrained_ctc_HLG.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HLG $repo/HLG.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +rm -rf $repo + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url 
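+# GIT_LFS_SKIP_SMUDGE=1 clones the repo without downloading the large LFS payloads;
+# each LFS-tracked file is left as a small pointer file. The `git lfs pull --include ...`
+# commands below then fetch only the files this test actually needs.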
+repo=$(basename $repo_url) +pushd $repo + +git lfs pull --include "exp/pretrained.pt" +git lfs pull --include "data/lang_bpe_500/HLG.pt" +git lfs pull --include "data/lang_bpe_500/L.pt" +git lfs pull --include "data/lang_bpe_500/L_disambig.pt" +git lfs pull --include "data/lang_bpe_500/Linv.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "data/lang_bpe_500/lexicon.txt" +git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt" +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "data/lang_bpe_500/words.txt" +git lfs pull --include "data/lm/G_3_gram.fst.txt" + +popd + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +log "CTC decoding" + +./conformer_ctc/pretrained.py \ + --method ctc-decoding \ + --num-classes 500 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "HLG decoding" + +./conformer_ctc/pretrained.py \ + --method 1best \ + --num-classes 500 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "CTC decoding on CPU with kaldi decoders using OpenFst" + +log "Exporting model with torchscript" + +pushd $repo/exp +ln -s pretrained.pt epoch-99.pt +popd + +./conformer_ctc/export.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --jit 1 + +ls -lh $repo/exp + + +log "Generating H.fst, HL.fst" + +./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt + +ls -lh $repo/data/lang_bpe_500 + +log "Decoding with H on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --H $repo/data/lang_bpe_500/H.fst \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Decoding with HL on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HL $repo/data/lang_bpe_500/HL.fst \ + --words $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Decoding with HLG on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HLG $repo/data/lang_bpe_500/HLG.fst \ + --words $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +rm -rf $repo + +popd + +log "Test aishell" + +pushd egs/aishell/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) +pushd $repo + +git lfs pull --include "exp/pretrained.pt" +git lfs pull --include "data/lang_char/H.fst" +git lfs pull --include "data/lang_char/HL.fst" +git lfs pull --include "data/lang_char/HLG.fst" + +popd + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +log "CTC decoding" + +log "Exporting 
model with torchscript" + +pushd $repo/exp +ln -s pretrained.pt epoch-99.pt +popd + +./conformer_ctc/export.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_char/tokens.txt \ + --jit 1 + +ls -lh $repo/exp + +ls -lh $repo/data/lang_char + +log "Decoding with H on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --H $repo/data/lang_char/H.fst \ + --tokens $repo/data/lang_char/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "Decoding with HL on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HL $repo/data/lang_char/HL.fst \ + --words $repo/data/lang_char/words.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +log "Decoding with HLG on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HLG $repo/data/lang_char/HLG.fst \ + --words $repo/data/lang_char/words.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + +rm -rf $repo diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh new file mode 100755 index 000000000..d8cc020e1 --- /dev/null +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/swbd/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s epoch-98.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + +for method in ctc-decoding 1best; do + log "$method" + + ./conformer_ctc/pretrained.py \ + --method $method \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --G $repo/data/lm/G_4_gram.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-ctc.yml similarity index 86% rename from .github/workflows/run-pretrained-conformer-ctc.yml rename to .github/workflows/run-pretrained-ctc.yml index 6151a5a14..074a63dfc 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-ctc.yml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: run-pre-trained-conformer-ctc +name: run-pre-trained-ctc on: push: @@ -23,13 +23,20 @@ on: pull_request: types: [labeled] + workflow_dispatch: + inputs: + test-run: + description: 'Test (y/n)?' 
+ required: true + default: 'y' + concurrency: - group: run_pre_trained_conformer_ctc-${{ github.ref }} + group: run_pre_trained_ctc-${{ github.ref }} cancel-in-progress: true jobs: - run_pre_trained_conformer_ctc: - if: github.event.label.name == 'ready' || github.event_name == 'push' + run_pre_trained_ctc: + if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' || github.event.label.name == 'ctc' runs-on: ${{ matrix.os }} strategy: matrix: @@ -77,4 +84,4 @@ jobs: export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-conformer-ctc.sh + .github/scripts/run-pre-trained-ctc.sh diff --git a/.github/workflows/run-swbd-conformer-ctc.yml b/.github/workflows/run-swbd-conformer-ctc.yml new file mode 100644 index 000000000..842691d38 --- /dev/null +++ b/.github/workflows/run-swbd-conformer-ctc.yml @@ -0,0 +1,84 @@ +# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-swbd-conformer_ctc + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +concurrency: + group: run-swbd-conformer_ctc-${{ github.ref }} + cancel-in-progress: true + +jobs: + run-swbd-conformer_ctc: + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'swbd' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 57f15fe87..7d55a50e1 100644 --- a/.github/workflows/run-yesno-recipe.yml 
+++ b/.github/workflows/run-yesno-recipe.yml @@ -60,7 +60,7 @@ jobs: - name: Install Python dependencies run: | - grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf pip install --no-binary protobuf protobuf==3.20.* @@ -140,9 +140,46 @@ jobs: download/waves_yesno/0_0_0_1_0_0_0_1.wav \ download/waves_yesno/0_0_1_0_0_0_1_0.wav + - name: Test decoding with H + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + + python3 ./tdnn/jit_pretrained_decode_with_H.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --H ./data/lang_phone/H.fst \ + --tokens ./data/lang_phone/tokens.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + + - name: Test decoding with HL + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + + python3 ./tdnn/jit_pretrained_decode_with_HL.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HL ./data/lang_phone/HL.fst \ + --words ./data/lang_phone/words.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + - name: Show generated files shell: bash working-directory: ${{github.workspace}} run: | cd egs/yesno/ASR ls -lh tdnn/exp + ls -lh data/lang_phone diff --git a/.gitignore b/.gitignore index 8af05d884..fa18ca83c 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ node_modules *.param *.bin .DS_Store +*.fst +*.arpa diff --git a/README.md b/README.md index a876fb24e..da446109d 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ We provide the following recipes: - [yesno][yesno] - [LibriSpeech][librispeech] - [GigaSpeech][gigaspeech] + - [AMI][ami] - [Aishell][aishell] - [Aishell2][aishell2] - [Aishell4][aishell4] @@ -37,6 +38,7 @@ We provide the following recipes: - [Aidatatang_200zh][aidatatang_200zh] - [WenetSpeech][wenetspeech] - [Alimeeting][alimeeting] + - [Switchboard][swbd] - [TAL_CSASR][tal_csasr] ### yesno @@ -118,9 +120,9 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles | Encoder | Params | test-clean | test-other | |-----------------|--------|------------|------------| -| zipformer | 65.5M | 2.21 | 4.91 | -| zipformer-small | 23.2M | 2.46 | 5.83 | -| zipformer-large | 148.4M | 2.11 | 4.77 | +| zipformer | 65.5M | 2.21 | 4.79 | +| zipformer-small | 23.2M | 2.42 | 5.73 | +| zipformer-large | 148.4M | 2.06 | 4.63 | Note: No auxiliary losses are used in the training and no LMs are used in the decoding. 
@@ -338,7 +340,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder #### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss -The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English): +The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English): |decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en | |--|--|--|--|--|--|--| |greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13| @@ -393,4 +395,6 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [wenetspeech]: egs/wenetspeech/ASR [alimeeting]: egs/alimeeting/ASR [tal_csasr]: egs/tal_csasr/ASR +[ami]: egs/ami +[swbd]: egs/swbd/ASR [k2]: https://github.com/k2-fsa/k2 diff --git a/docs/source/decoding-with-langugage-models/index.rst b/docs/source/decoding-with-langugage-models/index.rst index 6e5e3a4d9..c49da9a4e 100644 --- a/docs/source/decoding-with-langugage-models/index.rst +++ b/docs/source/decoding-with-langugage-models/index.rst @@ -2,12 +2,13 @@ Decoding with language models ============================= This section describes how to use external langugage models -during decoding to improve the WER of transducer models. +during decoding to improve the WER of transducer models. To train an external language model, +please refer to this tutorial: :ref:`train_nnlm`. The following decoding methods with external langugage models are available: -.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (The numbers in each field is WER on test-clean, WER on test-other and decoding time on test-clean) +.. list-table:: :widths: 25 50 :header-rows: 1 diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index 9eb5f85d2..634fb1e59 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -1,3 +1,5 @@ +.. _icefall_export_to_ncnn: + Export to ncnn ============== diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst index 2e8d0893a..37edf7de9 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst @@ -47,7 +47,7 @@ The data preparation contains several stages, you can use the following two options: - ``--stage`` - - ``--stop-stage`` + - ``--stop_stage`` to control which stage(s) should be run. By default, all stages are executed. @@ -56,8 +56,8 @@ For example, .. code-block:: bash $ cd egs/librispeech/ASR - $ ./prepare.sh --stage 0 --stop-stage 0 # run only stage 0 - $ ./prepare.sh --stage 2 --stop-stage 5 # run from stage 2 to stage 5 + $ ./prepare.sh --stage 0 --stop_stage 0 # run only stage 0 + $ ./prepare.sh --stage 2 --stop_stage 5 # run from stage 2 to stage 5 .. HINT:: @@ -108,15 +108,15 @@ As usual, you can control the stages you want to run by specifying the following two options: - ``--stage`` - - ``--stop-stage`` + - ``--stop_stage`` For example, .. 
code-block:: bash

     $ cd egs/librispeech/ASR
-    $ ./distillation_with_hubert.sh --stage 0 --stop-stage 0 # run only stage 0
-    $ ./distillation_with_hubert.sh --stage 2 --stop-stage 4 # run from stage 2 to stage 5
+    $ ./distillation_with_hubert.sh --stage 0 --stop_stage 0 # run only stage 0
+    $ ./distillation_with_hubert.sh --stage 2 --stop_stage 4 # run from stage 2 to stage 5
 
 Here are a few options in `./distillation_with_hubert.sh
 `_
 you need to know before you proceed.
@@ -134,7 +134,7 @@ and prepares MVQ-augmented training manifests.
 
 .. code-block:: bash
 
-   $ ./distillation_with_hubert.sh --stage 2 --stop-stage 2 # run only stage 2
+   $ ./distillation_with_hubert.sh --stage 2 --stop_stage 2 # run only stage 2
 
 Please see the following screenshot for the output of an example execution.
@@ -172,7 +172,7 @@ To perform training, please run stage 3 by executing the following command.
 
 .. code-block:: bash
 
-   $ ./prepare.sh --stage 3 --stop-stage 3 # run MVQ training
+   $ ./prepare.sh --stage 3 --stop_stage 3 # run MVQ training
 
 Here is the code snippet for training:
diff --git a/docs/source/recipes/RNN-LM/index.rst b/docs/source/recipes/RNN-LM/index.rst
new file mode 100644
index 000000000..4b74e64c7
--- /dev/null
+++ b/docs/source/recipes/RNN-LM/index.rst
@@ -0,0 +1,7 @@
+RNN-LM
+======
+
+.. toctree::
+   :maxdepth: 2
+
+   librispeech/lm-training
\ No newline at end of file
diff --git a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst
new file mode 100644
index 000000000..736120275
--- /dev/null
+++ b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst
@@ -0,0 +1,104 @@
+.. _train_nnlm:
+
+Train an RNN language model
+======================================
+
+If you have enough text data, you can train a neural network language model (NNLM) to improve
+the WER of your E2E ASR system. This tutorial shows you how to train an RNNLM from
+scratch.
+
+.. HINT::
+
+   For how to use an NNLM during decoding, please refer to the following tutorials:
+   :ref:`shallow_fusion`, :ref:`LODR`, :ref:`rescoring`
+
+.. note::
+
+   This tutorial is based on the LibriSpeech recipe. Please check it out for the necessary
+   Python scripts used in this tutorial. We use the LibriSpeech LM corpus as the LM training set
+   for illustration purposes. You can also collect your own data. The data format is quite simple:
+   each line should contain a complete sentence, and words should be separated by spaces.
+
+First, let's download the training data for the RNNLM. This can be done via the
+following command:
+
+.. code-block:: bash
+
+   $ wget https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+   $ gzip -d librispeech-lm-norm.txt.gz
+
+As we are training a BPE-level RNNLM, we need to tokenize the training text, which requires a
+BPE tokenizer. This can be achieved by executing the following command:
+
+.. code-block:: bash
+
+   $ # if you don't have the BPE model
+   $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+   $ cd icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500
+   $ git lfs pull --include bpe.model
+   $ cd ../../..
+
+   $ ./local/prepare_lm_training_data.py \
+      --bpe-model icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/bpe.model \
+      --lm-data librispeech-lm-norm.txt \
+      --lm-archive data/lang_bpe_500/lm_data.pt
+
+Now you should have a file named ``lm_data.pt`` stored under the directory ``data/lang_bpe_500``.
+This is the packed training data for the RNNLM.
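+
+If you want to sanity-check the archive, a minimal sketch such as the one below is enough. It only
+assumes that ``lm_data.pt`` is a ``torch``-serialized object; the exact layout is whatever
+``./local/prepare_lm_training_data.py`` wrote:
+
+.. code-block:: python
+
+   import torch
+
+   # Load the packed LM training data and list what it contains.
+   data = torch.load("data/lang_bpe_500/lm_data.pt", map_location="cpu")
+   print(type(data))
+   if isinstance(data, dict):
+       for key, value in data.items():
+           print(key, type(value))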
+
+We then sort the training data according to its sentence length.
+
+.. code-block:: bash
+
+   $ # This could take a while (~ 20 minutes), feel free to grab a cup of coffee :)
+   $ ./local/sort_lm_training_data.py \
+      --in-lm-data data/lang_bpe_500/lm_data.pt \
+      --out-lm-data data/lang_bpe_500/sorted_lm_data.pt \
+      --out-statistics data/lang_bpe_500/lm_data_stats.txt
+
+
+The aforementioned steps can be repeated to create a validation set for your RNNLM. Let's say
+you have a validation set in ``valid.txt``; you can just set ``--lm-data valid.txt``
+and ``--lm-archive data/lang_bpe_500/lm-data-valid.pt`` when calling ``./local/prepare_lm_training_data.py``.
+
+After completing the previous steps, the training and validation sets for the RNNLM are ready.
+The next step is to train the RNNLM. The training command is as follows:
+
+.. code-block:: bash
+
+   $ # assume you are in the icefall root directory
+   $ cd rnn_lm
+   $ ln -s ../../egs/librispeech/ASR/data .
+   $ cd ..
+   $ ./rnn_lm/train.py \
+      --world-size 4 \
+      --exp-dir ./rnn_lm/exp \
+      --start-epoch 0 \
+      --num-epochs 10 \
+      --use-fp16 0 \
+      --tie-weights 1 \
+      --embedding-dim 2048 \
+      --hidden_dim 2048 \
+      --num-layers 3 \
+      --batch-size 300 \
+      --lm-data rnn_lm/data/lang_bpe_500/sorted_lm_data.pt \
+      --lm-data-valid rnn_lm/data/lang_bpe_500/sorted_lm_data.pt
+
+
+.. note::
+
+   You can adjust the RNNLM hyperparameters to control the size of the RNNLM,
+   such as the embedding dimension and the hidden state dimension. For more details, please
+   run ``./rnn_lm/train.py --help``.
+
+.. note::
+
+   Training the RNNLM can take a long time (usually a couple of days).
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index 63793275c..7265e1cf6 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -15,3 +15,4 @@ We may add recipes for other tasks as well in the future.
Non-streaming-ASR/index Streaming-ASR/index + RNN-LM/index diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index c9d9c4aa8..fa809b768 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -635,7 +635,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -800,7 +799,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py old mode 100644 new mode 100755 index 1df3cfdc2..49871d437 --- a/egs/aishell/ASR/conformer_ctc/export.py +++ b/egs/aishell/ASR/conformer_ctc/export.py @@ -23,12 +23,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +63,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="""It contains language related input files such as "lexicon.txt" - """, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -98,16 +97,16 @@ def get_params() -> AttributeDict: def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - logging.info(params) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + num_classes = num_tokens(token_table) + 1 # +1 for the blank + + logging.info(params) device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py new file mode 120000 index 000000000..896b78aef --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py new file mode 120000 index 000000000..aa1b6073d --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py new file mode 120000 index 000000000..0cf42ce30 --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_ctc/test_transformer.py 
b/egs/aishell/ASR/conformer_ctc/test_transformer.py old mode 100644 new mode 100755 diff --git a/egs/aishell/ASR/local/prepare_lang_fst.py b/egs/aishell/ASR/local/prepare_lang_fst.py new file mode 120000 index 000000000..c5787c534 --- /dev/null +++ b/egs/aishell/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/train_bbpe_model.py b/egs/aishell/ASR/local/train_bbpe_model.py index d231d5d77..48160897d 100755 --- a/egs/aishell/ASR/local/train_bbpe_model.py +++ b/egs/aishell/ASR/local/train_bbpe_model.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - # You can install sentencepiece via: # # pip install sentencepiece @@ -26,12 +25,12 @@ # Please install a version >=0.1.96 import argparse -import re import shutil import tempfile from pathlib import Path import sentencepiece as spm + from icefall import byte_encode, tokenize_by_CJK_char @@ -74,6 +73,11 @@ def main(): model_type = "unigram" model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + model_file = Path(model_prefix + ".model") + if model_file.is_file(): + print(f"{model_file} exists - skipping") + return + character_coverage = 1.0 input_sentence_size = 100000000 @@ -88,23 +92,18 @@ def main(): _convert_to_bchar(args.transcript, train_text) - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - else: - print(f"{model_file} exists - skipping") - return + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) shutil.copyfile(model_file, f"{lang_dir}/bbpe.model") diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index ff8e1301d..9de060e73 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -143,6 +143,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./local/prepare_lang.py --lang-dir $lang_phone_dir fi + # Train a bigram P for MMI training if [ ! -f $lang_phone_dir/transcript_words.txt ]; then log "Generate data to train phone based bigram P" @@ -203,6 +204,10 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ ! -f $lang_char_dir/L_disambig.pt ]; then ./local/prepare_char.py --lang-dir $lang_char_dir fi + + if [ ! 
-f $lang_char_dir/HLG.fst ]; then + ./local/prepare_lang_fst.py --lang-dir $lang_phone_dir --ngram-G ./data/lm/G_3_gram.fst.txt + fi fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index d08908238..60f014c48 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -872,7 +872,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index 62e67530d..7c23041ca 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -1045,7 +1045,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py index e8211500a..2a9fc57d5 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py @@ -322,6 +322,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py index 5adb6c16a..a92182e8d 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -151,12 +151,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -170,6 +172,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index cbb7db086..11671db92 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -1028,7 +1028,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py index c30f6f960..057af297f 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py @@ -1031,7 +1031,7 @@ def run(rank, 
world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py index 4e52f9573..3858bafd7 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -1019,7 +1019,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index 180930747..6abe6c084 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -198,7 +198,7 @@ class AishellAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 74bf68ccb..8c7448d4c 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -730,7 +730,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -919,7 +918,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index 47015cbe7..a354f761e 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -908,7 +908,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index e57b5c859..30154291d 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -635,7 +635,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -800,7 +799,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py index 45d777922..8f09f1aa5 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -999,7 +999,7 @@ def run(rank, 
world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index 8c8d9593b..9b67141c0 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -988,7 +988,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py index 0c98885ac..2b9f2293a 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py @@ -330,6 +330,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py index eee19191e..cf6ddfa36 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -152,12 +152,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -171,6 +173,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 4bd5b83a2..4aedeffe4 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -1019,7 +1019,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py index 18cb75c37..73fcd67aa 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1074,7 +1074,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py index bc4bcf253..4c866ddd8 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py +++ 
b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -1075,7 +1075,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index b945f43fd..fc7fcdc26 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -75,7 +75,7 @@ See for more details. ##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -90,18 +90,20 @@ You can use to deploy it. | greedy_search | 2.23 | 4.96 | --epoch 40 --avg 16 | | modified_beam_search | 2.21 | 4.91 | --epoch 40 --avg 16 | | fast_beam_search | 2.24 | 4.93 | --epoch 40 --avg 16 | +| greedy_search | 2.22 | 4.87 | --epoch 50 --avg 25 | +| modified_beam_search | 2.21 | 4.79 | --epoch 50 --avg 25 | +| fast_beam_search | 2.21 | 4.82 | --epoch 50 --avg 25 | | modified_beam_search_shallow_fusion | 2.01 | 4.37 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.3 | | modified_beam_search_LODR | 1.94 | 4.17 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.52 --LODR-scale -0.26 | | modified_beam_search_rescore | 2.04 | 4.39 | --epoch 40 --avg 16 --beam-size 12 | | modified_beam_search_rescore_LODR | 2.01 | 4.33 | --epoch 40 --avg 16 --beam-size 12 | - The training command is: ```bash export CUDA_VISIBLE_DEVICES="0,1,2,3" ./zipformer/train.py \ --world-size 4 \ - --num-epochs 40 \ + --num-epochs 50 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ @@ -115,8 +117,8 @@ The decoding command is: export CUDA_VISIBLE_DEVICES="0" for m in greedy_search modified_beam_search fast_beam_search; do ./zipformer/decode.py \ - --epoch 30 \ - --avg 9 \ + --epoch 50 \ + --avg 25 \ --use-averaged-model 1 \ --exp-dir ./zipformer/exp \ --max-duration 600 \ @@ -129,7 +131,7 @@ To decode with external language models, please refer to the documentation [here ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -144,13 +146,16 @@ You can use to deploy it. | greedy_search | 2.49 | 5.91 | --epoch 40 --avg 13 | | modified_beam_search | 2.46 | 5.83 | --epoch 40 --avg 13 | | fast_beam_search | 2.46 | 5.87 | --epoch 40 --avg 13 | +| greedy_search | 2.46 | 5.86 | --epoch 50 --avg 23 | +| modified_beam_search | 2.42 | 5.73 | --epoch 50 --avg 23 | +| fast_beam_search | 2.46 | 5.78 | --epoch 50 --avg 23 | The training command is: ```bash export CUDA_VISIBLE_DEVICES="0,1" ./zipformer/train.py \ --world-size 2 \ - --num-epochs 40 \ + --num-epochs 50 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp-small \ @@ -169,8 +174,8 @@ The decoding command is: export CUDA_VISIBLE_DEVICES="0" for m in greedy_search modified_beam_search fast_beam_search; do ./zipformer/decode.py \ - --epoch 40 \ - --avg 13 \ + --epoch 50 \ + --avg 23 \ --exp-dir zipformer/exp-small \ --max-duration 600 \ --causal 0 \ @@ -185,7 +190,7 @@ done ##### large-scaled model, number of model parameters: 148439574, i.e., 148.4 M The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: @@ -200,13 +205,16 @@ You can use to deploy it. 
| greedy_search | 2.12 | 4.8 | --epoch 40 --avg 13 | | modified_beam_search | 2.11 | 4.7 | --epoch 40 --avg 13 | | fast_beam_search | 2.13 | 4.78 | --epoch 40 --avg 13 | +| greedy_search | 2.08 | 4.69 | --epoch 50 --avg 30 | +| modified_beam_search | 2.06 | 4.63 | --epoch 50 --avg 30 | +| fast_beam_search | 2.09 | 4.68 | --epoch 50 --avg 30 | The training command is: ```bash export CUDA_VISIBLE_DEVICES="0,1,2,3" ./zipformer/train.py \ --world-size 4 \ - --num-epochs 40 \ + --num-epochs 50 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp-large \ @@ -224,8 +232,8 @@ The decoding command is: export CUDA_VISIBLE_DEVICES="0" for m in greedy_search modified_beam_search fast_beam_search; do ./zipformer/decode.py \ - --epoch 40 \ - --avg 16 \ + --epoch 50 \ + --avg 30 \ --exp-dir zipformer/exp-large \ --max-duration 600 \ --causal 0 \ diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py new file mode 100755 index 000000000..4bdec9e11 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with H +on CPU using OpenFST and decoders from kaldi. + +Usage: + +(1) LibriSpeech conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --H ./data/lang_bpe_500/H.fst \ + --tokens ./data/lang_bpe_500/tokens.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + + +(2) AIShell conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --H ./data/lang_char/H.fst \ + --tokens ./data/lang_char/tokens.txt \ + ./BAC009S0764W0121.wav \ + ./BAC009S0764W0122.wav \ + ./BAC009S0764W0123.wav + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument("--H", type=str, required=True, help="Path to H.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_tokens(tokens_txt: str) -> Dict[int, str]: + id2token = dict() + with open(tokens_txt, encoding="utf-8") as f: + for line in f: + token, idx = line.strip().split() + id2token[int(idx)] = token + + return id2token + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. 
+ expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + H: + The H graph. + id2token: + A map mapping token ID to token string. + Returns: + Return a list of decoded tokens. + """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # tokens are incremented during graph construction + # so they need to be decremented + hyps = [id2token[i - 1] for i in osymbols_out] + # hyps = "".join(hyps).split("▁") + hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁ + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2token = read_tokens(args.tokens) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + H=H, + id2token=id2token, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding 
Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py new file mode 100755 index 000000000..d5a1dba3c --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with HL +on CPU using OpenFST and decoders from kaldi. + +Usage: + +(1) LibriSpeech conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HL ./data/lang_bpe_500/HL.fst \ + --words ./data/lang_bpe_500/words.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +(2) AIShell conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HL ./data/lang_char/HL.fst \ + --words ./data/lang_char/words.txt \ + ./BAC009S0764W0121.wav \ + ./BAC009S0764W0122.wav \ + ./BAC009S0764W0123.wav + + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. 
+ nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HL: + The HL graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + HL=HL, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py new file mode 100755 index 000000000..216677a23 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with HLG +on CPU using OpenFST and decoders from kaldi. 
+ +Usage: + +(1) LibriSpeech conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HLG ./data/lang_bpe_500/HLG.fst \ + --words ./data/lang_bpe_500/words.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +(2) AIShell conformer_ctc + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HLG ./data/lang_char/HLG.fst \ + --words ./data/lang_char/words.txt \ + ./BAC009S0764W0121.wav \ + ./BAC009S0764W0122.wav \ + ./BAC009S0764W0123.wav + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HLG: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HLG: + The HLG graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. 
+ """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + HLG=HLG, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 99fe64793..828106f41 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -557,7 +557,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c5a05d349..ca21bd6bf 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -953,7 +953,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = 
diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index cfd365207..ab046557f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -401,6 +401,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py index 5d7e2dfcd..a6c69d54f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -136,6 +136,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -184,6 +185,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -197,6 +199,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 6bb37b017..23ddb6bec 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -953,7 +953,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py index 36067510c..420dc1065 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py @@ -955,7 +955,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index 6aaa0333b..a5b0b85af 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -56,6 +56,8 @@ use_extracted_codebook=True # "hubert_xtralarge_ll60k" -> pretrained model without fintuing teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960 +. 
shared/parse_options.sh || exit 1 + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index da1648d06..5a36ff2a9 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -43,6 +43,7 @@ from pathlib import Path from tqdm.auto import tqdm + # This function is copied from lhotse def tqdm_urlretrieve_hook(t): """Wraps tqdm instance. diff --git a/egs/librispeech/ASR/local/prepare_lang_fst.py b/egs/librispeech/ASR/local/prepare_lang_fst.py new file mode 100755 index 000000000..fb1e7f9c0 --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_lang_fst.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang) + +""" +This script takes as input lang_dir containing lexicon_disambig.txt, +tokens.txt, and words.txt and generates the following files: + + - H.fst + - HL.fst + - HLG.fst + +Note that saved files are in OpenFst binary format. + +Usage: + +./local/prepare_lang_fst.py \ + --lang-dir ./data/lang_phone \ + --has-silence 1 + +Or + +./local/prepare_lang_fst.py \ + --lang-dir ./data/lang_bpe_500 +""" + +import argparse +import logging +from pathlib import Path + +import kaldifst + +from icefall.ctc import ( + Lexicon, + add_disambig_self_loops, + add_one, + build_standard_ctc_topo, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) +from icefall.utils import str2bool + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + parser.add_argument( + "--has-silence", + type=str2bool, + default=False, + help="True if the lexicon has silence.", + ) + + parser.add_argument( + "--ngram-G", + type=str, + help="""If not empty, it is the filename of G used to build HLG. 
+ For instance, --ngram-G=./data/lm/G_3_fst.txt + """, + ) + + return parser.parse_args() + + +def build_HL( + H: kaldifst.StdVectorFst, + L: kaldifst.StdVectorFst, + has_silence: bool, + lexicon: Lexicon, +) -> kaldifst.StdVectorFst: + if has_silence: + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + else: + add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) + + # Invoke add_disambig_self_loops() so that it eats the disambig symbols + # from L after composition + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id + 1, + ) + + kaldifst.arcsort(H, sort_type="olabel") + kaldifst.arcsort(L, sort_type="ilabel") + + HL = kaldifst.compose(H, L) + kaldifst.determinize_star(HL) + + disambig0 = lexicon.token2id["#0"] + 1 + max_disambig = lexicon.max_disambig_id + 1 + for state in kaldifst.StateIterator(HL): + for arc in kaldifst.ArcIterator(HL, state): + # If treat_ilabel_zero_specially is False, we always change it + # Otherwise, we only change non-zero input labels + if disambig0 <= arc.ilabel <= max_disambig: + arc.ilabel = 0 + + # Note: We are not composing L with G, so there is no need to add + # self-loops to L to handle #0 + + return HL + + +def build_HLG( + H: kaldifst.StdVectorFst, + L: kaldifst.StdVectorFst, + G: kaldifst.StdVectorFst, + has_silence: bool, + lexicon: Lexicon, +) -> kaldifst.StdVectorFst: + if has_silence: + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + else: + add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) + + # add-self-loops + token_disambig0 = lexicon.token2id["#0"] + 1 + word_disambig0 = lexicon.word2id["#0"] + + kaldifst.add_self_loops(L, isyms=[token_disambig0], osyms=[word_disambig0]) + + kaldifst.arcsort(L, sort_type="olabel") + kaldifst.arcsort(G, sort_type="ilabel") + LG = kaldifst.compose(L, G) + kaldifst.determinize_star(LG) + kaldifst.minimize_encoded(LG) + + kaldifst.arcsort(LG, sort_type="ilabel") + + # Invoke add_disambig_self_loops() so that it eats the disambig symbols + # from L after composition + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id + 1, + ) + + kaldifst.arcsort(H, sort_type="olabel") + + HLG = kaldifst.compose(H, LG) + kaldifst.determinize_star(HLG) + + disambig0 = lexicon.token2id["#0"] + 1 + max_disambig = lexicon.max_disambig_id + 1 + for state in kaldifst.StateIterator(HLG): + for arc in kaldifst.ArcIterator(HLG, state): + # If treat_ilabel_zero_specially is False, we always change it + # Otherwise, we only change non-zero input labels + if disambig0 <= arc.ilabel <= max_disambig: + arc.ilabel = 0 + return HLG + + +def copy_fst(fst): + # Please don't use fst.copy() + return kaldifst.StdVectorFst(fst) + + +def main(): + args = get_args() + lang_dir = args.lang_dir + + lexicon = Lexicon(lang_dir) + + logging.info("Building standard CTC topology") + max_token_id = max(lexicon.tokens) + H = build_standard_ctc_topo(max_token_id=max_token_id) + + # We need to add one to all tokens since we want to use ID 0 + # for epsilon + add_one(H, treat_ilabel_zero_specially=False, update_olabel=True) + H.write(f"{lang_dir}/H.fst") + + logging.info("Building L") + # Now for HL + + if args.has_silence: + L = make_lexicon_fst_with_silence(lexicon, attach_symbol_table=False) + else: + L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False) + + logging.info("Building HL") + HL = 
build_HL( + H=copy_fst(H), + L=copy_fst(L), + has_silence=args.has_silence, + lexicon=lexicon, + ) + HL.write(f"{lang_dir}/HL.fst") + + if not args.ngram_G: + logging.info("Skip building HLG") + return + + logging.info("Building HLG") + with open(args.ngram_G) as f: + G = kaldifst.compile( + s=f.read(), + acceptor=False, + ) + + HLG = build_HLG(H=H, L=L, G=G, has_silence=args.has_silence, lexicon=lexicon) + HLG.write(f"{lang_dir}/HLG.fst") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/long_file_recog/beam_search.py b/egs/librispeech/ASR/long_file_recog/beam_search.py index f8c31861c..b65e9d40a 100644 --- a/egs/librispeech/ASR/long_file_recog/beam_search.py +++ b/egs/librispeech/ASR/long_file_recog/beam_search.py @@ -236,7 +236,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -507,7 +507,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] diff --git a/egs/librispeech/ASR/long_file_recog/merge_chunks.py b/egs/librispeech/ASR/long_file_recog/merge_chunks.py index d38d9c86a..9e31e00d5 100755 --- a/egs/librispeech/ASR/long_file_recog/merge_chunks.py +++ b/egs/librispeech/ASR/long_file_recog/merge_chunks.py @@ -162,7 +162,6 @@ def merge_chunks( futures = [] with ThreadPoolExecutor(max_workers=1) as executor: - for cut in cuts_chunk: cur_rec_id = cut.recording.id if len(cut_list) == 0: diff --git a/egs/librispeech/ASR/long_file_recog/recognize.py b/egs/librispeech/ASR/long_file_recog/recognize.py index 96c83f859..466253446 100755 --- a/egs/librispeech/ASR/long_file_recog/recognize.py +++ b/egs/librispeech/ASR/long_file_recog/recognize.py @@ -264,6 +264,7 @@ def decode_dataset( - timestamps of reference transcript - timestamps of predicted result """ + # Background worker to add alignemnt and save cuts to disk. 
def _save_worker( cuts: List[Cut], diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py b/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py index 03dfe1997..91ef53e24 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py @@ -57,8 +57,7 @@ def test_model(): convert_scaled_to_non_scaled(model, inplace=True) - if not os.path.exists(params.exp_dir): - os.path.mkdir(params.exp_dir) + params.exp_dir.mkdir(exist_ok=True) encoder_filename = params.exp_dir / "encoder_jit_trace.pt" export_encoder_model_jit_trace(model.encoder, encoder_filename) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py index 89ced388c..2a52e2eec 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -359,6 +359,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py index 6b6cb893f..c543628ff 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -356,6 +356,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py index fb9e121e5..06159e56a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py @@ -129,6 +129,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -166,6 +167,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -179,6 +181,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 34d2e5630..487fc2114 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -172,30 +172,35 @@ class Model: self.encoder = ort.InferenceSession( args.encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, args): self.decoder = ort.InferenceSession( args.decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_joiner(self, args): self.joiner = ort.InferenceSession( args.joiner_model_filename, 
sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_joiner_encoder_proj(self, args): self.joiner_encoder_proj = ort.InferenceSession( args.joiner_encoder_proj_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_joiner_decoder_proj(self, args): self.joiner_decoder_proj = ort.InferenceSession( args.joiner_decoder_proj_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 8ce1eb478..93d010ea8 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -242,6 +242,10 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then $lang_dir/L_disambig.pt \ $lang_dir/L_disambig.fst fi + + if [ ! -f $lang_dir/HL.fst ]; then + ./local/prepare_lang_fst.py --lang-dir $lang_dir --ngram-G ./data/lm/G_3_gram.fst.txt + fi done fi diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 76cd4e11e..9f287ce70 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -66,7 +66,6 @@ class Eve(Optimizer): weight_decay=1e-3, target_rms=0.1, ): - if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 77e06d3b7..a4899f7bd 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -811,7 +811,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py index 282238c13..0a2132e56 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py @@ -307,6 +307,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 3298568a3..7fcd242fc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -719,7 +719,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1019,7 +1019,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1227,7 +1227,7 @@ def modified_beam_search_lm_rescore( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in 
enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1427,7 +1427,7 @@ def modified_beam_search_lm_rescore_LODR( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -2608,7 +2608,6 @@ def modified_beam_search_LODR( context_score = 0 new_context_state = None if context_graph is None else hyp.context_state if new_token not in (blank_id, unk_id): - if context_graph is not None: ( context_score, @@ -2758,7 +2757,7 @@ def modified_beam_search_lm_shallow_fusion( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] # get batch @@ -2900,7 +2899,6 @@ def modified_beam_search_lm_shallow_fusion( new_token = topk_token_indexes[k] new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): - ys.append(new_token) new_timestamp.append(t) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 2d7f557ad..f54bc2709 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -66,7 +66,6 @@ class Eve(Optimizer): weight_decay=1e-3, target_rms=0.1, ): - if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 963ebdc2d..91d64c1df 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -528,7 +528,6 @@ class ScaledLSTM(nn.LSTM): return with torch.cuda.device_of(first_fw): - # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is # an inplace operation on self._flat_weights with torch.no_grad(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index 26dea7e11..2685ea95a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -312,6 +312,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index e10915086..de3e03da6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -150,12 +150,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -169,6 +171,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( 
joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 810da8da6..b98248128 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -78,6 +78,7 @@ def test_conv2d_subsampling(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -133,6 +134,7 @@ def test_rel_pos(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -220,6 +222,7 @@ def test_conformer_encoder_layer(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -304,6 +307,7 @@ def test_conformer_encoder(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -359,6 +363,7 @@ def test_conformer(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 549fb13c9..b90d81dcf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -404,6 +404,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index fff0fcdd5..02aa24f2c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -335,6 +335,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py index 29be4c655..6e290e799 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py @@ -138,6 +138,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -185,6 +186,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -198,6 +200,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = 
self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 3b5a635e4..66dc5f991 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -1003,7 +1003,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 14ff86f23..3bca7db2c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -56,7 +56,6 @@ class CodebookIndexExtractor: """ def __init__(self, params: AttributeDict): - self.params = params params.subsets = ["clean-100"] if self.params.full_libri: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py index 76cd56bbb..bfb5fe609 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py @@ -111,7 +111,7 @@ def batch_force_alignment( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index b085a1817..bfd019ff5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -71,6 +71,10 @@ class Decoder(nn.Module): groups=decoder_dim // 4, # group size == 4 bias=False, ) + else: + # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` + # when inference with torch.jit.script and context_size == 1 + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index 11c885f4d..b75548f8b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -329,6 +329,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index 3ee2b7d65..4e261dbc1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -1132,7 +1132,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index aa3cef338..8ab3589da 100644 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -117,7 +117,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): parameters_names=None, show_dominant_parameters=True, ): - assert parameters_names is not None, ( "Please prepare parameters_names," "which is a List[List[str]]. Each List[str] is for a group" @@ -224,9 +223,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -325,7 +322,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -410,7 +407,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -426,7 +423,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -1039,7 +1035,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py index 1e9b67226..f3f7b1ea9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py @@ -74,6 +74,7 @@ def test_conv2d_subsampling(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -128,6 +129,7 @@ def test_rel_pos(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -204,6 +206,7 @@ def test_zipformer_encoder_layer(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -284,6 +287,7 @@ def test_zipformer_encoder(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() @@ -338,6 +342,7 @@ def test_zipformer(): session = ort.InferenceSession( filename, sess_options=options, + providers=["CPUExecutionProvider"], ) input_nodes = session.get_inputs() diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 2b4d51089..fac3706d2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1028,7 +1028,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index b387968a9..d8fa08372 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1052,7 +1052,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py index 8ff02fbcb..494a34d97 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -326,41 +326,49 @@ def main(): encoder = ort.InferenceSession( args.encoder_model_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) decoder = ort.InferenceSession( args.decoder_model_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) joiner = ort.InferenceSession( args.joiner_model_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) joiner_encoder_proj = ort.InferenceSession( args.joiner_encoder_proj_model_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) joiner_decoder_proj = ort.InferenceSession( args.joiner_decoder_proj_model_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) lconv = ort.InferenceSession( args.lconv_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) frame_reducer = ort.InferenceSession( args.frame_reducer_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) ctc_output = ort.InferenceSession( args.ctc_output_filename, sess_options=session_opts, + providers=["CPUExecutionProvider"], ) sp = spm.SentencePieceProcessor() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 23fb6f497..25a1aa674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1042,7 +1042,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py index 8653126de..2de56837e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -413,6 +413,7 @@ def export_decoder_model_onnx( context_size = 
decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index 6f84d79b4..d71080760 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -401,6 +401,7 @@ def export_decoder_model_onnx( context_size = decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py index 8192e01fd..04861ea37 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -130,6 +130,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -229,6 +230,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -242,6 +244,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 99090b2c1..2d915ff87 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1029,7 +1029,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py index 9be629149..aa6c0668a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -1030,7 +1030,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index a5c422959..c7e45564f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -865,7 +865,7 @@ class ZipformerEncoderLayer(nn.Module): return final_dropout_rate else: return initial_dropout_rate - ( - initial_dropout_rate * 
final_dropout_rate + initial_dropout_rate - final_dropout_rate ) * (self.batch_count / warmup_period) def forward( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index b494253d6..565dc7a16 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -1141,7 +1141,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index bee414292..3f271c5b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -1154,7 +1154,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index 5fe92172e..0b982f4bf 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -230,7 +230,9 @@ class Conformer(Transformer): x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask ) # (T, B, F) else: - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + x = self.encoder( + x, pos_emb, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) if self.normalize_before: x = self.after_norm(x) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index bb55ed6bb..14d7274c2 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -543,7 +543,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 0aa1587ba..90245ed46 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -463,7 +463,6 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}" ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index f2a09346c..9ac6b7d03 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -513,7 +513,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index a6f2bd08c..92134116c 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -517,7 +517,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if 
tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index e8db988f6..e77e54118 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -61,10 +61,15 @@ class Decoder(nn.Module): ) # the balancers are to avoid any drift in the magnitude of the # embeddings, which would interact badly with parameter averaging. - self.balancer = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) self.blank_id = blank_id @@ -81,10 +86,15 @@ class Decoder(nn.Module): groups=decoder_dim // 4, # group size == 4 bias=False, ) - self.balancer2 = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer2 = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ @@ -107,9 +117,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py new file mode 100755 index 000000000..3345d20d3 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a CTC model from PyTorch to ONNX. + +Note that the model is trained using both transducer and CTC loss. This script +exports only the CTC head. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./zipformer/export-onnx-ctc.py \ + --use-transducer 0 \ + --use-ctc 1 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size 16 \ + --left-context-frames 128 + +It will generate the following 2 files inside $repo/exp: + + - model.onnx + - model.int8.onnx + +See ./onnx_pretrained_ctc.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for encoder_embed, Zipformer, and ctc_output layer""" + + def __init__( + self, + encoder: Zipformer2, + encoder_embed: nn.Module, + ctc_output: nn.Module, + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_embed: + The first downsampling layer for zipformer. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.ctc_output = ctc_output + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - log_probs, a 3-D tensor of shape (N, T', vocab_size) + - log_probs_len, a 1-D int64 tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, log_probs_len = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + log_probs = self.ctc_output(encoder_out) + + return log_probs, log_probs_len + + +def export_ctc_model_onnx( + model: OnnxModel, + filename: str, + opset_version: int = 11, +) -> None: + """Export the given model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - log_probs, a tensor of shape (N, T', joiner_dim) + - log_probs_len, a tensor of shape (N,) + + Args: + model: + The input model + filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+    """
+    x = torch.zeros(1, 100, 80, dtype=torch.float32)
+    x_lens = torch.tensor([100], dtype=torch.int64)
+
+    model = torch.jit.trace(model, (x, x_lens))
+
+    torch.onnx.export(
+        model,
+        (x, x_lens),
+        filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["x", "x_lens"],
+        output_names=["log_probs", "log_probs_len"],
+        dynamic_axes={
+            "x": {0: "N", 1: "T"},
+            "x_lens": {0: "N"},
+            "log_probs": {0: "N", 1: "T"},
+            "log_probs_len": {0: "N"},
+        },
+    )
+
+    meta_data = {
+        "model_type": "zipformer2_ctc",
+        "version": "1",
+        "model_author": "k2-fsa",
+        "comment": "non-streaming zipformer2 CTC",
+    }
+    logging.info(f"meta_data: {meta_data}")
+
+    add_meta_data(filename=filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints(filenames, device=device), strict=False
+            )
+        elif params.avg == 1:
+            load_checkpoint(
+                f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
+            )
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints(filenames, device=device), strict=False
+            )
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                ),
+                strict=False,
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + model = OnnxModel( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + ctc_output=model.ctc_output, + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"num parameters: {num_param}") + + opset_version = 13 + + logging.info("Exporting ctc model") + filename = params.exp_dir / f"model.onnx" + export_ctc_model_onnx( + model, + filename, + opset_version=opset_version, + ) + logging.info(f"Exported to {filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + filename_int8 = params.exp_dir / f"model.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index a951aeef3..e2c7d7d95 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -506,6 +506,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index e0d664009..3682f0b62 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -353,6 +353,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index f03cc930e..dfb0a0057 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -52,12 +52,13 @@ class Joiner(nn.Module): Returns: Return a tensor of shape (N, T, s_range, C). 
""" - assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) + assert encoder_out.ndim == decoder_out.ndim, ( + encoder_out.shape, + decoder_out.shape, + ) if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/zipformer/onnx_decode.py b/egs/librispeech/ASR/zipformer/onnx_decode.py index 2aca36ca9..356c2a830 100755 --- a/egs/librispeech/ASR/zipformer/onnx_decode.py +++ b/egs/librispeech/ASR/zipformer/onnx_decode.py @@ -303,7 +303,9 @@ def main(): for test_set, test_dl in zip(test_sets, test_dl): start_time = time.time() - results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table) + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) end_time = time.time() elapsed_seconds = end_time - start_time rtf = elapsed_seconds / total_duration diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index 500b2cd09..e62491444 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -146,6 +146,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -236,6 +237,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -249,6 +251,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index 032b07721..334376093 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -151,12 +151,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -170,6 +172,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py new file mode 100755 index 000000000..eb5cee9cd --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. 
+ +2. Run this file + +./zipformer/onnx_pretrained_ctc.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + blank_id = 0 + s = "\n" + for i in range(log_probs.size(0)): + # greedy search + indexes = log_probs[i, : log_probs_len[i]].argmax(dim=-1) + token_ids = torch.unique_consecutive(indexes) + + token_ids = token_ids[token_ids != blank_id] + words = token_ids_to_words(token_ids.tolist()) + s += f"{args.sound_files[i]}:\n{words}\n\n" + + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py new file mode 100755 index 000000000..683a7dc20 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. Run this file + +./zipformer/onnx_pretrained_ctc_H.py \ + --nn-model /path/to/model.onnx \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + --H /path/to/H.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +from typing import Dict +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. 
", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--H", + type=str, + help="""Path to H.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + H: + The H graph. + id2word: + A map mapping token ID to word string. + Returns: + Return a list of decoded words. 
+ """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # tokens are incremented during graph construction + # are shifted by 1 during graph construction + hyps = [id2token[i - 1] for i in osymbols_out if i != 1] + hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁ + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + token_table = k2.SymbolTable.from_file(args.tokens) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + H=H, + id2token=token_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py new file mode 100755 index 000000000..0b94bfa65 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. 
Run this file + +./zipformer/onnx_pretrained_ctc_HL.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HL /path/to/HL.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +from typing import Dict +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HL", + type=str, + help="""Path to HL.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. 
+ log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HL: + The HL graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HL=HL, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py new file mode 100755 index 000000000..93569142a --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +""" +This script loads ONNX models and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 +as an example to show how to use this file. + +1. Please follow ./export-onnx-ctc.py to get the onnx model. + +2. 
Run this file + +./zipformer/onnx_pretrained_ctc_HLG.py \ + --nn-model /path/to/model.onnx \ + --words /path/to/data/lang_bpe_500/words.txt \ + --HLG /path/to/HLG.fst \ + 1089-134686-0001.wav \ + 1221-135766-0001.wav \ + 1221-135766-0002.wav + +You can find exported ONNX models at +https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +from typing import Dict +import kaldifst +import onnxruntime as ort +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="Path to the onnx model. ", + ) + + parser.add_argument( + "--words", + type=str, + help="""Path to words.txt.""", + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.fst.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + nn_model: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(nn_model) + + def init_model(self, nn_model: str): + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + def __call__( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D float tensor of shape (N, T, C) + x_lens: + A 1-D int64 tensor of shape (N,) + Returns: + Return a tuple containing: + - A float tensor containing log_probs of shape (N, T, C) + - A int64 tensor containing log_probs_len of shape (N) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + log_probs: torch.Tensor, + HLG: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. 
+ log_probs: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HLG: + The HLG graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ + logging.info(f"{filename}, {log_probs.shape}") + decodable = DecodableCtc(log_probs.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + logging.info(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + logging.info(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + nn_model=args.nn_model, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + log_probs, log_probs_len = model(features, feature_lengths) + + word_table = k2.SymbolTable.from_file(args.words) + + hyps = [] + for i in range(log_probs.shape[0]): + hyp = decode( + filename=args.sound_files[i], + log_probs=log_probs[i, : log_probs_len[i]], + HLG=HLG, + id2word=word_table, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index abfb2092c..8ee2b0eb4 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, ): - defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -299,8 +298,8 @@ class ScaledAdam(BatchedOptimizer): # the input is groups of parameter or named parameter. 
for cur_group in iterable_or_groups: assert "named_params" in cur_group - name_list = [ x[0] for x in cur_group["named_params"] ] - p_list = [ x[1] for x in cur_group["named_params"] ] + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] del cur_group["named_params"] cur_group["params"] = p_list param_groups.append(cur_group) @@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -428,7 +425,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -494,6 +491,12 @@ class ScaledAdam(BatchedOptimizer): if self.show_dominant_parameters: assert p.shape[0] == len(param_names) self._show_gradient_dominating_parameter(tuples, tot_sumsq) + if ans != ans: # e.g. ans is nan + ans = 0.0 + if ans == 0.0: + for p, state, param_names in tuples: + p.grad.zero_() # get rid of infinity() + return ans def _show_gradient_dominating_parameter( @@ -513,7 +516,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -529,7 +532,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -577,7 +579,7 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad if clipping_scale != 1.0: - grad = grad * clipping_scale + grad *= clipping_scale step = state["step"] delta = state["delta"] @@ -667,8 +669,7 @@ class ScaledAdam(BatchedOptimizer): # We have to look at the trained model for parameters at or around the # param_max_rms, because sometimes they can indicate a problem with the # topology or settings. - scale_step = torch.minimum(scale_step, - (param_max_rms - param_rms) / param_rms) + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) delta = state["delta"] # the factor of (1-beta1) relates to momentum. 
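Beyond the formatting cleanups, the ScaledAdam hunk above adds a guard to the gradient-clipping scale: if the computed scale is NaN (ans != ans) it is forced to 0.0, and a zero scale zeroes every gradient in the batch so that a single inf/NaN gradient cannot poison the parameters. A minimal standalone sketch of that idiom, with made-up names (safe_clipping_scale is not part of the patch, and the real code compares the gradient norm against statistics of recent norms):

import torch

def safe_clipping_scale(tot_sumsq: torch.Tensor, threshold: float) -> float:
    # Simplified stand-in for ScaledAdam's clipping-scale computation.
    ans = min(1.0, threshold / (tot_sumsq.sqrt().item() + 1.0e-20))
    if ans != ans:  # NaN is the only value that does not equal itself
        ans = 0.0
    return ans

grads = [torch.tensor([float("inf"), 1.0]), torch.tensor([2.0, 3.0])]
tot_sumsq = sum((g**2).sum() for g in grads)
scale = safe_clipping_scale(tot_sumsq, threshold=5.0)
if scale == 0.0:
    for g in grads:
        g.zero_()  # skip this step entirely rather than propagate inf/NaN
print(scale, grads)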
@@ -879,7 +880,8 @@ class Eden(LRScheduler): warmup_factor = ( 1.0 if self.batch >= self.warmup_batches - else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) # else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) @@ -1111,7 +1113,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/librispeech/ASR/zipformer/profile.py b/egs/librispeech/ASR/zipformer/profile.py index b460b5338..57f44a90a 100755 --- a/egs/librispeech/ASR/zipformer/profile.py +++ b/egs/librispeech/ASR/zipformer/profile.py @@ -100,17 +100,13 @@ class Model(nn.Module): self.encoder_embed = encoder_embed self.encoder_proj = encoder_proj - def forward( - self, feature: Tensor, feature_lens: Tensor - ) -> Tuple[Tensor, Tensor]: + def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]: x, x_lens = self.encoder_embed(feature, feature_lens) src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder( - x, x_lens, src_key_padding_mask - ) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) logits = self.encoder_proj(encoder_out) @@ -168,9 +164,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7c98ef045..c0f1e3087 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -25,6 +25,7 @@ import math import torch.nn as nn from torch import Tensor + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) diff = torch.abs(x - y) @@ -55,28 +56,34 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: # for torch.jit.trace() return torch.logaddexp(x, y) + class PiecewiseLinear(object): """ Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] respectively. """ + def __init__(self, *args): assert len(args) >= 1, len(args) if len(args) == 1 and isinstance(args[0], PiecewiseLinear): self.pairs = list(args[0].pairs) else: - self.pairs = [ (float(x), float(y)) for x,y in args ] - for (x,y) in self.pairs: + self.pairs = [(float(x), float(y)) for x, y in args] + for x, y in self.pairs: assert isinstance(x, (float, int)), type(x) assert isinstance(y, (float, int)), type(y) for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], (i, self.pairs[i], self.pairs[i + 1]) + assert self.pairs[i + 1][0] > self.pairs[i][0], ( + i, + self.pairs[i], + self.pairs[i + 1], + ) def __str__(self): # e.g. 
'PiecewiseLinear((0., 10.), (100., 0.))' - return f'PiecewiseLinear({str(self.pairs)[1:-1]})' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" def __call__(self, x): if x <= self.pairs[0][0]: @@ -93,37 +100,36 @@ class PiecewiseLinear(object): assert False def __mul__(self, alpha): - return PiecewiseLinear( - * [(x, y * alpha) for x, y in self.pairs]) + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) def __add__(self, x): if isinstance(x, (float, int)): - return PiecewiseLinear( - * [(p[0], p[1] + x) for p in self.pairs]) + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) s, x = self.get_common_basis(x) return PiecewiseLinear( - * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) def max(self, x): if isinstance(x, (float, int)): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def min(self, x): if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def __eq__(self, other): return self.pairs == other.pairs - def get_common_basis(self, - p: 'PiecewiseLinear', - include_crossings: bool = False): + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): """ Returns (self_mod, p_mod) which are equivalent piecewise linear functions to self and p, but with the same x values. @@ -135,28 +141,30 @@ class PiecewiseLinear(object): assert isinstance(p, PiecewiseLinear), type(p) # get sorted x-values without repetition. - x_vals = sorted(set([ x for x, _ in self.pairs ] + [ x for x, _ in p.pairs ])) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] + x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] if include_crossings: extra_x_vals = [] for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): # if the two lines in this subsegment potentially cross each other.. diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) # `pos`, between 0 and 1, gives the relative x position, # with 0 being x_vals[i] and 1 being x_vals[i+1]. 
pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) extra_x_vals.append(extra_x_val) if len(extra_x_vals) > 0: x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - return ( PiecewiseLinear(* zip(x_vals, y_vals1)), - PiecewiseLinear(* zip(x_vals, y_vals2)) ) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) class ScheduledFloat(torch.nn.Module): @@ -176,9 +184,8 @@ class ScheduledFloat(torch.nn.Module): `default` is used when self.batch_count is not set or not in training mode or in torch.jit scripting mode. """ - def __init__(self, - *args, - default: float = 0.0): + + def __init__(self, *args, default: float = 0.0): super().__init__() # self.batch_count and self.name will be written to in the training loop. self.batch_count = None @@ -187,47 +194,55 @@ class ScheduledFloat(torch.nn.Module): self.schedule = PiecewiseLinear(*args) def extra_repr(self) -> str: - return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) def __float__(self): batch_count = self.batch_count - if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + if ( + batch_count is None + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): return float(self.default) else: ans = self.schedule(self.batch_count) if random.random() < 0.0002: - logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") + logging.info( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) return ans def __add__(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, - default=self.default) + return ScheduledFloat(self.schedule + x, default=self.default) else: - return ScheduledFloat(self.schedule + x.schedule, - default=self.default+x.default) + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) def max(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), - default=self.default) + return ScheduledFloat(self.schedule.max(x), default=self.default) else: - return ScheduledFloat(self.schedule.max(x.schedule), - default=max(self.default, x.default)) + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) FloatLike = Union[float, ScheduledFloat] -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -242,6 +257,7 @@ class CutoffEstimator: p is the proportion of items that should be above the cutoff. 
""" + def __init__(self, p: float): self.p = p # total count of items @@ -255,7 +271,7 @@ class CutoffEstimator: """ Returns true if x is above the cutoff. """ - ans = (x > self.cutoff) + ans = x > self.cutoff self.count += 1 if ans: self.count_above += 1 @@ -263,7 +279,7 @@ class CutoffEstimator: delta_p = cur_p - self.p if (delta_p > 0) == ans: q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1-q) + self.cutoff = x * q + self.cutoff * (1 - q) return ans @@ -272,6 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -287,7 +304,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -306,17 +323,16 @@ def softmax(x: Tensor, dim: int): class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x @staticmethod @@ -328,15 +344,20 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -348,8 +369,14 @@ class BiasNormFunction(torch.autograd.Function): # it can just store the returned value (chances are, this will also be needed for # some other reason, related to the next operation, so we can save memory). 
@staticmethod - def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, - store_output_for_backprop: bool) -> Tensor: + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: assert bias.ndim == 1 if channel_dim < 0: channel_dim = channel_dim + x.ndim @@ -357,10 +384,16 @@ class BiasNormFunction(torch.autograd.Function): ctx.channel_dim = channel_dim for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales - ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, - scales.detach(), bias.detach(), log_scale.detach()) + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) return ans @staticmethod @@ -376,7 +409,9 @@ class BiasNormFunction(torch.autograd.Function): log_scale.requires_grad = True with torch.enable_grad(): # recompute scales from x, bias and log_scale. - scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales ans.backward(gradient=ans_grad) return x.grad, bias.grad.flatten(), log_scale.grad, None, None @@ -412,14 +447,15 @@ class BiasNorm(torch.nn.Module): than the input of this module to be required to be stored for the backprop. """ + def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, ) -> None: super(BiasNorm, self).__init__() self.num_channels = num_channels @@ -442,23 +478,24 @@ class BiasNorm(torch.nn.Module): bias = self.bias for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * - self.log_scale.exp()) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() return x * scales - log_scale = limit_param_value(self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training) + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) - return BiasNormFunction.apply(x, self.bias, log_scale, - self.channel_dim, - self.store_output_for_backprop) + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. 
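The ScaledLinear/ScaledConv1d/ScaledConv2d helpers touched in the surrounding hunks all follow the same pattern: construct the stock torch.nn module, then shrink its initial weights (and re-initialize the bias) by initial_scale so that freshly added layers start out with small outputs. A condensed, standalone restatement of that pattern (scaled_linear below is illustrative; the real helper is ScaledLinear in scaling.py):

import torch
import torch.nn as nn

def scaled_linear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    # Build a regular nn.Linear, then scale down its initial parameters.
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
    return ans

proj = scaled_linear(256, 512, initial_scale=0.25)
print(proj.weight.abs().mean())  # roughly 0.25x the default initialization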
@@ -477,15 +514,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -504,15 +537,11 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans -def ScaledConv2d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv2d: +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: """ Behaves like a constructor of a modified version of nn.Conv2d that gives an easy way to set the default initial parameter scale. @@ -532,9 +561,7 @@ def ScaledConv2d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans @@ -562,29 +589,36 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): Another option, if you want to do something like this, is to re-initialize the parameters. """ - def __init__(self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True): + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): super().__init__() assert kernel_size % 2 == 1 half_kernel_size = (kernel_size + 1) // 2 # will pad manually, on one side. - self.causal_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True) + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) - self.chunkwise_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias) + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) # first row is correction factors added to the scale near the left edge of the chunk, # second row is correction factors added to the scale near the right edge of the chunk, @@ -596,17 +630,15 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): self.causal_conv.weight[:] *= initial_scale self.chunkwise_conv.weight[:] *= initial_scale if bias: - torch.nn.init.uniform_(self.causal_conv.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) - def forward(self, - x: Tensor, - chunk_size: int = -1) -> Tensor: + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + Forward function. 
Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. """ (batch_size, num_channels, seq_len) = x.shape @@ -622,28 +654,32 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): x = torch.nn.functional.pad(x, (left_pad, right_pad)) - x_causal = self.causal_conv(x[..., :left_pad + seq_len]) + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) assert x_causal.shape == (batch_size, num_channels, seq_len) x_chunk = x[..., left_pad:] num_chunks = x_chunk.shape[2] // chunk_size x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, - num_channels, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) x_chunk = self.chunkwise_conv(x_chunk) # does not change shape chunk_scale = self._get_chunk_scale(chunk_size) x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape(batch_size, num_chunks, - num_channels, chunk_size).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] return x_chunk + x_causal def _get_chunk_scale(self, chunk_size: int): """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" + scale the output of self.chunkwise_conv.""" left_edge = self.chunkwise_conv_scale[0] right_edge = self.chunkwise_conv_scale[1] if chunk_size < self.kernel_size: @@ -652,9 +688,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): else: t = chunk_size - self.kernel_size channels = left_edge.shape[0] - pad = torch.zeros(channels, t, - device=left_edge.device, - dtype=left_edge.dtype) + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) left_edge = torch.cat((left_edge, pad), dim=-1) right_edge = torch.cat((pad, right_edge), dim=-1) return 1.0 + (left_edge + right_edge) @@ -698,14 +734,14 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): class BalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim @@ -715,10 +751,8 @@ class BalancerFunction(torch.autograd.Function): return x @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None]: - x, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config try: @@ -727,8 +761,8 @@ class BalancerFunction(torch.autograd.Function): x = x.to(torch.float32) x = x.detach() x.requires_grad = True - mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] - uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) mean = x.mean(dim=mean_dims, keepdim=True) stddev = (uncentered_var - (mean * 
mean)).clamp(min=1.0e-20).sqrt() rms = uncentered_var.clamp(min=1.0e-20).sqrt() @@ -742,11 +776,16 @@ class BalancerFunction(torch.autograd.Function): rms_clamped = rms.clamp(min=min_rms, max=max_rms) r_loss = (rms_clamped / rms).log().abs() - loss = (m_loss + r_loss) + loss = m_loss + r_loss loss.backward(gradient=torch.ones_like(loss)) loss_grad = x.grad - loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) loss_grad = loss_grad * (grad_scale / loss_grad_rms) @@ -757,7 +796,9 @@ class BalancerFunction(torch.autograd.Function): x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) x_grad = x_grad_mod.to(x_grad.dtype) except Exception as e: - logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) return x_grad, None, None, None, None, None, None @@ -793,16 +834,17 @@ class Balancer(torch.nn.Module): on each forward(). This is done randomly to prevent all layers from doing it at the same time. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, ): super().__init__() @@ -823,8 +865,11 @@ class Balancer(torch.nn.Module): self.grad_scale = grad_scale def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad or - (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): return _no_op(x) prob = float(self.prob) @@ -842,7 +887,7 @@ class Balancer(torch.nn.Module): eps = 1.0e-10 # eps is to prevent crashes if x is exactly 0 or 1. # we'll just end up returning a fairly large value. - return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 def _approx_inverse_erf(x): # 1 / (sqrt(pi) * ln(2)), @@ -853,6 +898,7 @@ class Balancer(torch.nn.Module): # and math.erf(0.0407316414078772) = 0.045935330944660666, # which is pretty close to 0.05. return 0.8139535143 * _atanh(x) + # first convert x from the range 0..1 to the range -1..1 which the error # function returns x = -1 + (2 * x) @@ -873,8 +919,9 @@ class Balancer(torch.nn.Module): return _no_op(x) -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, - name: str = None) -> Tensor: +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: """ Returns x unmodified, but in backprop will put a penalty for the excess of the absolute values of elements of x over the limit "limit". E.g. if @@ -910,13 +957,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. 
else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -946,25 +992,22 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - module: nn.Module) -> Tensor: + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: ctx.save_for_backward(x) ctx.module = module return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors w = ctx.module try: @@ -976,8 +1019,10 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, w.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") + logging.info( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" + ) if metric < float(w.whitening_limit): w.prob = w.min_prob @@ -986,22 +1031,27 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = w.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None except Exception as e: - logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." 
+ ) return x_grad, None class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float,float]], - grad_scale: FloatLike): + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -1033,10 +1083,9 @@ class Whiten(nn.Module): (self.min_prob, self.max_prob) = prob assert 0 < self.min_prob <= self.max_prob <= 1 self.prob = self.max_prob - self.name = None # will be set in training loop + self.name = None # will be set in training loop - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -1071,9 +1120,11 @@ class WithLoss(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device), None + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) def with_loss(x, y, name): @@ -1118,20 +1169,21 @@ class LimitParamValue(torch.autograd.Function): @staticmethod def backward(ctx, x_grad: Tensor): - x, = ctx.saved_tensors + (x,) = ctx.saved_tensors # where x < ctx.min, ensure all grads are negative (this will tend to make # x more positive). - x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) # where x > ctx.max, ensure all grads are positive (this will tend to make # x more negative). x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) return x_grad, None, None -def limit_param_value(x: Tensor, - min: float, max: float, - prob: float = 0.6, - training: bool = True): +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): # You apply this to (typically) an nn.Parameter during training to ensure that its # (elements mostly) stays within a supplied range. This is done by modifying the # gradients in backprop. @@ -1187,7 +1239,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 @@ -1197,7 +1249,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.044 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1210,12 +1264,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. 
floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) + d = d * ((ceil - floor) / 255.0) + floor return y_grad * d @@ -1239,9 +1293,7 @@ class Dropout2(nn.Module): self.p = p def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, - p=float(self.p), - training=self.training) + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) class MulForDropout3(torch.autograd.Function): @@ -1259,7 +1311,7 @@ class MulForDropout3(torch.autograd.Function): @staticmethod @custom_bwd def backward(ctx, ans_grad): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors x_grad = ctx.alpha * ans_grad * (ans != 0) return x_grad, None, None @@ -1286,7 +1338,7 @@ class Dropout3(nn.Module): class SwooshLFunction(torch.autograd.Function): """ - swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 """ @staticmethod @@ -1308,13 +1360,15 @@ class SwooshLFunction(torch.autograd.Function): if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = coeff ceil = 1.0 + coeff + 0.005 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1328,20 +1382,19 @@ class SwooshLFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. coeff = -0.08 floor = coeff ceil = 1.0 + coeff + 0.005 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshL(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ + """Return Swoosh-L activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 @@ -1351,19 +1404,19 @@ class SwooshL(torch.nn.Module): return k2.swoosh_l(x) # return SwooshLFunction.apply(x) + class SwooshLOnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ + """Return Swoosh-L activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 class SwooshRFunction(torch.autograd.Function): """ - swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - derivatives are between -0.08 and 0.92. + derivatives are between -0.08 and 0.92. """ @staticmethod @@ -1379,17 +1432,19 @@ class SwooshRFunction(torch.autograd.Function): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = -0.08 ceil = 0.925 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. 
assert d_scaled.min() >= 0.0 @@ -1403,33 +1458,32 @@ class SwooshRFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.08 ceil = 0.925 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshR(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ + """Return Swoosh-R activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not x.requires_grad: return k2.swoosh_r_forward(x) else: return k2.swoosh_r(x) # return SwooshRFunction.apply(x) + class SwooshROnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ + """Return Swoosh-R activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687 + return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 # simple version of SwooshL that does not redefine the backprop, used in @@ -1437,7 +1491,7 @@ class SwooshROnnx(torch.nn.Module): def SwooshLForward(x: Tensor): x_offset = x - 4.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.035 @@ -1446,28 +1500,30 @@ def SwooshLForward(x: Tensor): def SwooshRForward(x: Tensor): x_offset = x - 1.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.313261687 class ActivationDropoutAndLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int]): + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): if dropout_p != 0.0: dropout_shape = list(x.shape) if dropout_shared_dim is not None: dropout_shape[dropout_shared_dim] = 1 # else it won't be very memory efficient. - dropout_mask = ((1.0 / (1.0 - dropout_p)) * - (torch.rand(*dropout_shape, - device=x.device, dtype=x.dtype) > dropout_p)) + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) else: dropout_mask = None @@ -1476,8 +1532,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): ctx.activation = activation forward_activation_dict = { - 'SwooshL': k2.swoosh_l_forward, - 'SwooshR': k2.swoosh_r_forward + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. 
@@ -1495,8 +1551,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): (x, weight, bias, dropout_mask) = saved forward_and_deriv_activation_dict = { - 'SwooshL': k2.swoosh_l_forward_and_deriv, - 'SwooshR': k2.swoosh_r_forward_and_deriv + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, } # the following lines a KeyError if the activation is unrecognized. # This will be an error. We let it propagate to the user. @@ -1511,8 +1567,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): in_channels = y.shape[-1] g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), - y.reshape(-1, in_channels)) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) y_deriv = torch.matmul(ans_grad, weight) bias_deriv = None if bias is None else g.sum(dim=0) x_deriv = y_deriv * func_deriv @@ -1525,71 +1580,76 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): class ActivationDropoutAndLinear(torch.nn.Module): """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). 
""" - def __init__(self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = 'SwooshL', - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0): + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): super().__init__() # create a temporary module of nn.Linear that we'll steal the # weights and bias from - l = ScaledLinear(in_channels, out_channels, - bias=bias, - initial_scale=initial_scale) + l = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) self.weight = l.weight # register_parameter properly handles making it a parameter when l.bias # is None. I think there is some reason for doing it this way rather # than just setting it to None but I don't know what it is, maybe # something to do with exporting the module.. - self.register_parameter('bias', l.bias) + self.register_parameter("bias", l.bias) self.activation = activation self.dropout_p = dropout_p self.dropout_shared_dim = dropout_shared_dim - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == 'SwooshL': + if self.activation == "SwooshL": x = SwooshLForward(x) elif self.activation == "SwooshR": x = SwooshRForward(x) else: assert False, self.activation - return torch.nn.functional.linear(x, - self.weight, - self.bias) + return torch.nn.functional.linear(x, self.weight, self.bias) return ActivationDropoutAndLinearFunction.apply( - x, self.weight, self.bias, self.activation, - float(self.dropout_p), self.dropout_shared_dim) + x, + self.weight, + self.bias, + self.activation, + float(self.dropout_p), + self.dropout_shared_dim, + ) def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: @@ -1612,10 +1672,9 @@ def _test_whiten(): x.requires_grad = True - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1656,9 +1715,7 @@ def _test_balancer_sign(): def _test_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = Balancer( @@ -1685,7 +1742,7 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) # for self-test. @@ -1699,7 +1756,7 @@ def _test_swooshl_deriv(): x.requires_grad = True m = SwooshL() - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. @@ -1713,7 +1770,7 @@ def _test_swooshr_deriv(): x.requires_grad = True m = SwooshR() - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. 
@@ -1727,24 +1784,24 @@ def _test_softmax(): b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) def _test_piecewise_linear(): - p = PiecewiseLinear( (0, 10.0) ) + p = PiecewiseLinear((0, 10.0)) for x in [-100, 0, 100]: assert p(x) == 10.0 - p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) - for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: print("x, y = ", x, y) assert p(x) == y, (x, p(x), y) q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] pq = p.max(q) for x in x_vals: y1 = max(p(x), q(x)) @@ -1757,7 +1814,7 @@ def _test_piecewise_linear(): assert abs(y1 - y2) < 0.001 pq = p + q for x in x_vals: - y1 = p(x) + q(x) + y1 = p(x) + q(x) y2 = pq(x) assert abs(y1 - y2) < 0.001 @@ -1772,15 +1829,22 @@ def _test_activation_dropout_and_linear(): # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() # internally, messing up the random state. for dropout_p in [0.0]: - for activation in ['SwooshL', 'SwooshR']: - m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=0.5)) - m2 = ActivationDropoutAndLinear(in_channels, out_channels, - bias=bias, initial_scale=0.5, - activation=activation, - dropout_p=dropout_p) + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) with torch.no_grad(): m2.weight[:] = m1[2].weight if bias: @@ -1790,9 +1854,9 @@ def _test_activation_dropout_and_linear(): x1.requires_grad = True # TEMP. - assert torch.allclose(SwooshRFunction.apply(x1), - SwooshRForward(x1), - atol=1.0e-03) + assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) x2 = x1.clone().detach() x2.requires_grad = True @@ -1805,21 +1869,24 @@ def _test_activation_dropout_and_linear(): y2 = m2(x2) y2.backward(gradient=y_grad) - print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) print("y1 = ", y1) print("y2 = ", y2) assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) print("x1.grad = ", x1.grad) print("x2.grad = ", x2.grad) def isclose(a, b): # return true if cosine similarity is > 0.9. 
- return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + # the SwooshL() implementation has a noisy gradient due to 1-byte # storage of it. assert isclose(x1.grad, x2.grad) diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 44ff392a3..904caf8af 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: ) batch_states.append(cached_embed_left_pad) - processed_lens = torch.cat( - [state_list[i][-1] for i in range(batch_size)], dim=0 - ) + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) batch_states.append(processed_lens) return batch_states @@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk( - chunks=batch_size, dim=1 - ) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( chunks=batch_size, dim=1 @@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: cached_conv2_list[i], ] - cached_embed_left_pad_list = batch_states[-2].chunk( - chunks=batch_size, dim=0 - ) + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) for i in range(batch_size): state_list[i].append(cached_embed_left_pad_list[i]) @@ -380,11 +374,7 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, @@ -404,9 +394,7 @@ def streaming_forward( new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] @@ -494,9 +482,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = torch.tensor(processed_lens, device=device) processed_lens = processed_lens + encoder_out_lens @@ -517,9 +503,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = unstack_states(new_states) @@ -577,9 +561,7 @@ def decode_dataset( decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
- initial_states = get_init_states( - model=model, batch_size=1, device=device - ) + initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -649,9 +631,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -684,8 +664,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -718,9 +697,7 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" assert params.causal, params.causal - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." @@ -760,9 +737,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -789,9 +766,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 6532ddccb..d16d87bac 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -107,9 +107,7 @@ class ConvNeXt(nn.Module): if layerdrop_rate != 0.0: batch_size = x.shape[0] mask = ( - torch.rand( - (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device - ) + torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate ) else: @@ -278,9 +276,7 @@ class Conv2dSubsampling(nn.Module): # many copies of this extra gradient term. 
self.out_whiten = Whiten( num_groups=1, - whitening_limit=ScheduledFloat( - (0.0, 4.0), (20000.0, 8.0), default=4.0 - ), + whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), prob=(0.025, 0.25), grad_scale=0.02, ) @@ -331,7 +327,7 @@ class Conv2dSubsampling(nn.Module): with warnings.catch_warnings(): warnings.simplefilter("ignore") x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max()) + assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) return x, x_lens @@ -403,8 +399,8 @@ class Conv2dSubsampling(nn.Module): left_pad = self.convnext.padding[0] freq = self.out_width channels = self.layer3_channels - cached_embed_left_pad = torch.zeros( - batch_size, channels, left_pad, freq - ).to(device) + cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to( + device + ) return cached_embed_left_pad diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index bc3e9c1ba..7009f3346 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: - assert ( - params.use_transducer or params.use_ctc - ), (f"At least one of them should be True, " + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}") + f"params.use_ctc={params.use_ctc}" + ) encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) @@ -808,17 +808,16 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss += ( - simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss @@ -1166,7 +1165,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index c1b3ea3e0..4b50acdde 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -981,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/WSASR/README.md b/egs/librispeech/WSASR/README.md new file mode 100644 index 000000000..3b8822fd2 --- /dev/null +++ b/egs/librispeech/WSASR/README.md @@ -0,0 +1,224 @@ +# Introduction + +This is a weakly supervised ASR recipe for the LibriSpeech (clean 100 hours) dataset. 
We train a
+conformer model using [Bypass Temporal Classification](https://arxiv.org/pdf/2306.01031.pdf) (BTC) / [Omni-temporal Classification](https://arxiv.org/pdf/2309.15796.pdf) (OTC) on transcripts containing synthetic errors. In this README, we describe
+the task and the BTC/OTC training process.
+
+Note that OTC is an extension of BTC and supports all BTC functions. Therefore, in the following, we only describe OTC.
+
+## Task
+We propose BTC/OTC to directly train an ASR system leveraging weak supervision, i.e., speech with non-verbatim transcripts. This is achieved by using a special token $\star$ to model uncertainties (i.e., substitution errors, insertion errors, and deletion errors)
+within the WFST framework during training.
+
+
+*(Figures 1-3: examples of errors (substitution, insertion, and deletion) in the transcript. The grey box is the verbatim transcript and the red box is the inaccurate transcript. Inaccurate words are marked in bold.)*
+
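+In case the figures above do not render, here is a toy illustration of the three error types (the sentences are made up for illustration and are not taken from LibriSpeech):
+
+```python
+# Hypothetical example of the three kinds of transcript corruption
+# that BTC/OTC is designed to tolerate.
+verbatim    = "the cat sat on the mat"
+substituted = "the dog sat on the mat"          # substitution: "cat" -> "dog"
+inserted    = "the cat quickly sat on the mat"  # insertion: extra word "quickly"
+deleted     = "the cat sat on mat"              # deletion: "the" dropped
+```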
+We modify $G(\mathbf{y})$ by adding a self-loop arc to each state and a bypass arc for each arc.
+

+*(Figure: the modified WFST $G(\mathbf{y})$ with the added self-loop arcs and bypass arcs.)*
+
+ +We incorporate the penalty strategy and apply different configurations for the self-loop arc and bypass arc. The penalties are set as + +$$\lambda_{1_{i}} = \beta_{1} * \tau_{1}^{i},\quad \lambda_{2_{i}} = \beta_{2} * \tau_{2}^{i}$$ + +for the $i$-th training epoch. $\beta$ is the initial penalty that encourages the model to rely more on the given transcript at the start of training. +It decays exponentially by a factor of $\tau \in (0, 1)$, gradually encouraging the model to align speech with $\star$ when getting confused. + +After composing the modified WFST $G_{\text{otc}}(\mathbf{y})$ with $L$ and $T$, the OTC training graph is shown in this figure: +
+*(Figure: OTC training graph. The self-loop arcs and bypass arcs are highlighted in green and blue, respectively.)*
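+As a rough illustration of the penalty schedule $\lambda_{1_{i}} = \beta_{1} \tau_{1}^{i}$, $\lambda_{2_{i}} = \beta_{2} \tau_{2}^{i}$ described above, here is a minimal sketch. It assumes that the `--initial-bypass-weight`, `--initial-self-loop-weight`, `--bypass-weight-decay`, and `--self-loop-weight-decay` flags shown in the Training section below play the roles of $\beta$ and $\tau$:
+
+```python
+# Minimal sketch of the exponentially decaying OTC penalties; the numbers
+# are the example values used in the Training section of this README.
+beta_bypass, tau_bypass = -19.0, 0.975        # bypass-arc penalty
+beta_self_loop, tau_self_loop = 3.75, 0.999   # self-loop-arc penalty
+
+for epoch in range(3):
+    lambda_bypass = beta_bypass * tau_bypass ** epoch
+    lambda_self_loop = beta_self_loop * tau_self_loop ** epoch
+    print(f"epoch {epoch}: bypass penalty {lambda_bypass:.3f}, "
+          f"self-loop penalty {lambda_self_loop:.3f}")
+```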
+
+ +The $\star$ is represented as the average probability of all non-blank tokens. +

+ +

+ +The weight of $\star$ is the log average probability of "a" and "b": $\log \frac{e^{-1.2} + e^{-2.3}}{2} = -1.6$ and $\log \frac{e^{-1.9} + e^{-0.5}}{2} = -1.0$ for 2 frames. + +## Description of the recipe +### Preparation +``` +# feature_type can be ssl or fbank +feature_type=ssl +feature_dir="data/${feature_type}" +manifest_dir="${feature_dir}" +lang_dir="data/lang" +lm_dir="data/lm" +exp_dir="conformer_ctc2/exp" +otc_token="" + +./prepare.sh \ + --feature-type "${feature_type}" \ + --feature-dir "${feature_dir}" \ + --lang-dir "${lang_dir}" \ + --lm-dir "${lm_dir}" \ + --otc-token "${otc_token}" +``` +This script adds the 'otc_token' ('\') and its corresponding sentence-piece ('▁\') to 'words.txt' and 'tokens.txt,' respectively. Additionally, it computes SSL features using the 'wav2vec2-base' model. (You can use GPU to accelerate feature extraction). + +### Making synthetic errors to the transcript (train-clean-100) [optional] +``` +sub_er=0.17 +ins_er=0.17 +del_er=0.17 +synthetic_train_manifest="librispeech_cuts_train-clean-100_${sub_er}_${ins_er}_${del_er}.jsonl.gz" + +./local/make_error_cutset.py \ + --input-cutset "${manifest_dir}/librispeech_cuts_train-clean-100.jsonl.gz" \ + --words-file "${lang_dir}/words.txt" \ + --sub-error-rate "${sub_er}" \ + --ins-error-rate "${ins_er}" \ + --del-error-rate "${del_er}" \ + --output-cutset "${manifest_dir}/${synthetic_train_manifest}" +``` +This script generates synthetic substitution, insertion, and deletion errors in the transcript with ratios 'sub_er', 'ins_er', and 'del_er', respectively. The original transcript is saved as 'verbatim transcript' in the cutset, along with information on how the transcript is corrupted: + + - '[hello]' indicates the original word 'hello' is substituted by another word + - '[]' indicates an extra word is inserted into the transcript + - '-hello-' indicates the word 'hello' is deleted from the transcript + +So if the original transcript is "have a nice day" and the synthetic one is "a very good day", the 'verbatim transcript' would be: +``` +original: have a nice day +synthetic: a very good day +verbatim: -have- a [] [nice] day +``` + +### Training +The training uses synthetic data based on the train-clean-100 subset. +``` +otc_lang_dir=data/lang_bpe_200 + +allow_bypass_arc=true +allow_self_loop_arc=true +initial_bypass_weight=-19 +initial_self_loop_weight=3.75 +bypass_weight_decay=0.975 +self_loop_weight_decay=0.999 + +show_alignment=true + +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./conformer_ctc2/train.py \ + --world-size 4 \ + --manifest-dir "${manifest_dir}" \ + --train-manifest "${synthetic_train_manifest}" \ + --exp-dir "${exp_dir}" \ + --lang-dir "${otc_lang_dir}" \ + --otc-token "${otc_token}" \ + --allow-bypass-arc "${allow_bypass_arc}" \ + --allow-self-loop-arc "${allow_self_loop_arc}" \ + --initial-bypass-weight "${initial_bypass_weight}" \ + --initial-self-loop-weight "${initial_self_loop_weight}" \ + --bypass-weight-decay "${bypass_weight_decay}" \ + --self-loop-weight-decay "${self_loop_weight_decay}" \ + --show-alignment "${show_alignment}" +``` +The bypass arc deals with substitution and insertion errors, while the self-loop arc deals with deletion errors. Using "--show-alignment" would print the best alignment during training, which is very helpful for tuning hyperparameters and debugging. 
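+Before moving on to decoding, here is a quick numerical check of the $\star$ weights quoted in the Task section (a self-contained sketch, independent of the recipe scripts):
+
+```python
+import math
+
+# The weight of the star token is the log of the average probability of the
+# non-blank tokens ("a" and "b" in the Task-section example) at each frame.
+frame_log_probs = [
+    [-1.2, -2.3],  # frame 1: log p("a"), log p("b")
+    [-1.9, -0.5],  # frame 2: log p("a"), log p("b")
+]
+for i, log_probs in enumerate(frame_log_probs, 1):
+    star_weight = math.log(sum(math.exp(lp) for lp in log_probs) / len(log_probs))
+    print(f"frame {i}: star weight = {star_weight:.1f}")
+# prints approximately -1.6 and -1.0, matching the values in the Task section
+```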
+
+### Decoding
+```
+export CUDA_VISIBLE_DEVICES="0"
+./conformer_ctc2/decode.py \
+  --manifest-dir "${manifest_dir}" \
+  --exp-dir "${exp_dir}" \
+  --lang-dir "${otc_lang_dir}" \
+  --lm-dir "${lm_dir}" \
+  --otc-token "${otc_token}"
+```
+
+### Results (ctc-greedy-search)
+
+| Training Criterion | ssl: test-clean | ssl: test-other | fbank: test-clean | fbank: test-other |
+|--------------------|-----------------|-----------------|-------------------|-------------------|
+| CTC                | 100.0           | 100.0           | 99.89             | 99.98             |
+| OTC                | 11.89           | 25.46           | 20.14             | 44.24             |
+
+### Results (1best, blank_bias=-4)
+
+| Training Criterion | ssl: test-clean | ssl: test-other | fbank: test-clean | fbank: test-other |
+|--------------------|-----------------|-----------------|-------------------|-------------------|
+| CTC                | 98.40           | 98.68           | 99.79             | 99.86             |
+| OTC                | 6.59            | 15.98           | 11.78             | 32.38             |
+ +## Pre-trained Model +Pre-trained model: + +## Citations +``` +@inproceedings{gao2023bypass, + title={Bypass Temporal Classification: Weakly Supervised Automatic Speech Recognition with Imperfect Transcripts}, + author={Gao, Dongji and Wiesner, Matthew and Xu, Hainan and Garcia, Leibny Paola and Povey, Daniel and Khudanpur, Sanjeev}, + booktitle={INTERSPEECH}, + year={2023} +} + +@inproceedings{gao2023learning, + title={Learning from Flawed Data: Weakly Supervised Automatic Speech Recognition}, + author={Gao, Dongji and Xu, Hainan and Raj, Desh and Garcia, Leibny Paola and Povey, Daniel and Khudanpur, Sanjeev}, + booktitle={IEEE ASRU}, + year={2023} +} +``` diff --git a/egs/librispeech/WSASR/conformer_ctc2/__init__.py b/egs/librispeech/WSASR/conformer_ctc2/__init__.py new file mode 120000 index 000000000..43a85af20 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/__init__.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/__init__.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py new file mode 100644 index 000000000..1b6991bcd --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py @@ -0,0 +1,369 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# 2023 John Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=False, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/ssl"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--train-manifest", + type=str, + default="librispeech_cuts_train-clean-100.jsonl.gz", + help="Train manifest file.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." 
+ ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy(self.args.manifest_dir / self.args.train_manifest) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + 
logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/librispeech/WSASR/conformer_ctc2/attention.py b/egs/librispeech/WSASR/conformer_ctc2/attention.py new file mode 120000 index 000000000..e808a6f20 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/attention.py @@ -0,0 +1 @@ +../../ASR/conformer_ctc2/attention.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/conformer.py b/egs/librispeech/WSASR/conformer_ctc2/conformer.py new file mode 100644 index 000000000..db4821d37 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/conformer.py @@ -0,0 +1,949 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# 2022 Xiaomi Corp. (author: Quandong Wang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import Optional, Tuple + +import torch +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledLinear, +) +from subsampling import Conv2dSubsampling, Conv2dSubsampling2 +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension, also the output dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + layer_dropout (float): layer-dropout rate. + cnn_module_kernel (int): Kernel size of convolution module + vgg_frontend (bool): whether to use vgg frontend. 
+ """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 2, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.2, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + layer_dropout=layer_dropout, + ) + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4 and subsampling_factor != 2: + raise NotImplementedError("Support only 'subsampling_factor=4 or 2'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + if self.subsampling_factor == 4: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + elif self.subsampling_factor == 2: + self.encoder_embed = Conv2dSubsampling2(num_features, d_model) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + layer_dropout, + cnn_module_kernel, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + + def run_encoder( + self, + x: torch.Tensor, + supervisions: Optional[Supervisions] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), self.subsampling_factor, supervisions) + if mask is not None: + mask = mask.to(x.device) + + # Caution: We assume the subsampling factor is 4! + + x = self.encoder( + x, pos_emb, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) + + # x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + # return x, lengths + return x, mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). 
+ dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + + self.layer_dropout = layer_dropout + + self.d_model = d_model + + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.feed_forward_macaron = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # multi-headed self-attention module + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = src + self.dropout(src_att) + + # convolution module + src = src + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src + + +class ConformerEncoder(nn.Module): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). 
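Editor's note: the warmup/layer_dropout logic in ConformerEncoderLayer.forward blends each layer's output with its input: during training a layer is effectively bypassed (alpha = 0.1) with probability layer_dropout, and otherwise faded in via warmup_scale. A toy sketch of just that blending rule, pulled out of the full layer:

import torch


def blend_with_bypass(src_orig, src_new, warmup, layer_dropout, training=True):
    """Return alpha * src_new + (1 - alpha) * src_orig, following the rule above."""
    warmup_scale = min(0.1 + warmup, 1.0)
    if training:
        alpha = (
            warmup_scale
            if torch.rand(()).item() <= (1.0 - layer_dropout)
            else 0.1
        )
    else:
        alpha = 1.0
    return src_new if alpha == 1.0 else alpha * src_new + (1 - alpha) * src_orig


x = torch.zeros(4)
y = torch.ones(4)
print(blend_with_bypass(x, y, warmup=0.5, layer_dropout=0.075))  # usually 0.6*y + 0.4*x
print(blend_with_bypass(x, y, warmup=2.0, layer_dropout=0.075))  # y, unless the layer is bypassed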
+
+    Examples::
+        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = conformer_encoder(src, pos_emb)
+    """
+
+    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+    def forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        warmup: float = 1.0,
+    ) -> Tensor:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required).
+            pos_emb: Positional embedding tensor (required).
+            mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, 2*S-1, E)
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number
+
+        """
+        output = src
+
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                pos_emb,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )
+
+        return output
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+        """Construct a RelPositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: Tensor) -> None:
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of the query vector and `j` is the
+        # position of the key vector. We use positive relative positions when
+        # keys are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+
+        """
+        self.extend_pe(x)
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding
+
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+    Args:
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear( + embed_dim, embed_dim, bias=True, initial_scale=0.25 + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() + + def _pos_bias_u(self): + return self.pos_bias_u * self.pos_bias_u_scale.exp() + + def _pos_bias_v(self): + return self.pos_bias_v * self.pos_bias_v_scale.exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.get_weight(), + self.in_proj.get_bias(), + self.dropout, + self.out_proj.get_weight(), + self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
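Editor's note: rel_shift above uses as_strided to turn the (batch, head, time1, 2*time1-1) relative-position scores into (batch, head, time1, time1). A sketch that reimplements the same shift with explicit indexing and checks it against an as_strided version; the naive loop makes the intended indexing, out[..., i, j] = x[..., i, (time1 - 1) - i + j], easier to see:

import torch


def rel_shift_strided(x: torch.Tensor) -> torch.Tensor:
    """as_strided version, mirroring RelPositionMultiheadAttention.rel_shift."""
    batch, heads, time1, n = x.shape
    assert n == 2 * time1 - 1
    return x.as_strided(
        (batch, heads, time1, time1),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (time1 - 1),
    )


def rel_shift_naive(x: torch.Tensor) -> torch.Tensor:
    """Same result via explicit indexing."""
    batch, heads, time1, n = x.shape
    out = torch.empty(batch, heads, time1, time1, dtype=x.dtype)
    for i in range(time1):
        for j in range(time1):
            out[:, :, i, j] = x[:, :, i, time1 - 1 - i + j]
    return out


x = torch.randn(2, 4, 5, 9)  # (batch, head, time1, 2*time1-1)
assert torch.allclose(rel_shift_strided(x), rel_shift_naive(x))
print("rel_shift implementations agree")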
+ + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
+ ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self._pos_bias_u()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self._pos_bias_v()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). 
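Editor's note: the attention weights above are masked by filling padded key positions with -inf before the softmax, which zeroes their probability mass while the remaining weights still sum to one. A tiny demonstration of that effect on a toy score matrix:

import torch

scores = torch.randn(1, 3, 5)  # (batch*heads, time1, time2)
key_padding_mask = torch.tensor([[False, False, False, True, True]])  # last 2 keys are padding

masked = scores.masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))
probs = torch.softmax(masked, dim=-1)

print(probs[0, 0])  # the last two entries are exactly 0
assert torch.all(probs[..., 3:] == 0)
assert torch.allclose(probs.sum(dim=-1), torch.ones(1, 3))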
+ + """ + + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = ScaledConv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + + self.depthwise_conv = ScaledConv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.25, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +if __name__ == "__main__": + feature_dim = 50 + c = Conformer(num_features=feature_dim, d_model=128, nhead=4) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. 
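Editor's note: the ConvolutionModule above is a pointwise conv + GLU, a depthwise conv, and a second pointwise conv, built from icefall-specific ScaledConv1d/ActivationBalancer/DoubleSwish layers. A simplified sketch of the same data flow using only stock PyTorch layers (SiLU standing in for DoubleSwish, no scaling or balancing), just to make the shapes and the GLU gating explicit; it is not the recipe's module:

import torch
import torch.nn as nn


class SimpleConvModule(nn.Module):
    def __init__(self, channels: int, kernel_size: int = 31):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0  # 'SAME' padding needs an odd kernel
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(
            channels, channels, kernel_size, padding=(kernel_size - 1) // 2, groups=channels
        )
        self.activation = nn.SiLU()  # stand-in for DoubleSwish
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (time, batch, channels) -> (batch, channels, time)
        x = x.permute(1, 2, 0)
        x = nn.functional.glu(self.pointwise_conv1(x), dim=1)  # gate half the channels
        x = self.activation(self.depthwise_conv(x))
        x = self.pointwise_conv2(x)
        return x.permute(2, 0, 1)  # back to (time, batch, channels)


x = torch.randn(20, 5, 64)  # (time, batch, channels)
print(SimpleConvModule(64)(x).shape)  # torch.Size([20, 5, 64])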
+ f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode.py b/egs/librispeech/WSASR/conformer_ctc2/decode.py new file mode 100755 index 000000000..3fa045533 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/decode.py @@ -0,0 +1,718 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Quandong Wang) +# 2023 Johns Hopkins University (Author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import get_lattice, one_best_decoding +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.otc_graph_compiler import OtcTrainingGraphCompiler +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--otc-token", + type=str, + default="", + help="OTC token", + ) + + parser.add_argument( + "--blank-bias", + type=float, + default=0, + help="bias (log-prob) added to blank token during decoding", + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--method", + type=str, + default="ctc-greedy-search", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) ctc-greedy-search. It only use CTC output and a sentence piece + model for decoding. It produces the same results with ctc-decoding. + - (2) 1best. 
Extract the best path from the decoding lattice as the + decoding result. + """, + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=0, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_200", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 2, + "feature_dim": 768, + "nhead": 8, + "dim_feedforward": 2048, + "encoder_dim": 512, + "num_encoder_layers": 12, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def ctc_greedy_search( + nnet_output: torch.Tensor, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, +) -> List[List[int]]: + """Apply CTC greedy search + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + Returns: + List[List[int]]: best path result + """ + batch_size = memory.shape[1] + # Let's assume B = batch_size + encoder_out = memory + encoder_mask = memory_key_padding_mask + maxlen = encoder_out.size(0) + + ctc_probs = nnet_output # (B, maxlen, vocab_size) + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + topk_index = topk_index.masked_fill_(encoder_mask, 0) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + scores = topk_prob.max(1) + hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + return hyps, scores + + +def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. 
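Editor's note: ctc_greedy_search below picks the arg-max token per frame and then collapses repeats and blanks with remove_duplicates_and_blank. A quick standalone check of that collapsing rule on a hand-made frame sequence (0 is the blank id); the function body is copied from the recipe so the example is self-contained:

from typing import List


def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
    # Skip blanks (0) and merge consecutive repeats of the same token.
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != 0:
            new_hyp.append(hyp[cur])
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp


frames = [0, 0, 7, 7, 0, 7, 3, 3, 0, 0, 5]
print(remove_duplicates_and_blank(frames))  # [7, 7, 3, 5]: the blank separates the two 7s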
`len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + nnet_output[:, :, 0] += params.blank_bias + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="trunc", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="trunc", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor + 2, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "ctc-greedy-search": + hyps, _ = ctc_greedy_search( + nnet_output, + memory, + memory_key_padding_mask, + ) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(hyps) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... 
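Editor's note: decode_one_batch above builds supervision_segments by stacking sequence_idx with start_frame and num_frames, both divided (truncating) by the subsampling factor. A toy example of that tensor construction, assuming a supervisions dict shaped like the one produced by K2SpeechRecognitionDataset:

import torch

subsampling_factor = 2
supervisions = {
    "sequence_idx": torch.tensor([0, 1, 2]),
    "start_frame": torch.tensor([0, 0, 0]),
    "num_frames": torch.tensor([1000, 801, 643]),
}

supervision_segments = torch.stack(
    (
        supervisions["sequence_idx"],
        torch.div(supervisions["start_frame"], subsampling_factor, rounding_mode="trunc"),
        torch.div(supervisions["num_frames"], subsampling_factor, rounding_mode="trunc"),
    ),
    1,
).to(torch.int32)

print(supervision_segments)
# tensor([[  0,   0, 500],
#         [  1,   0, 400],
#         [  2,   0, 321]], dtype=torch.int32)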
] + hyps = [s.split() for s in hyps] + key = "ctc-greedy-search" + return {key: hyps} + + if params.method in ["1best"]: + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + + return {key: hyps} + else: + assert False, f"Unsupported decoding method: {params.method}" + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. 
+ enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + assert "▁" not in args.otc_token + args.otc_token = f"▁{args.otc_token}" + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + # remove otc_token from decoding units + max_token_id = max(lexicon.tokens) - 1 + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = OtcTrainingGraphCompiler( + params.lang_dir, + params.otc_token, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.method == "ctc-decoding" or params.method == "ctc-greedy-search": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.encoder_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_encoder_layers=params.num_encoder_layers, + num_decoder_layers=params.num_decoder_layers, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints 
({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
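Editor's note: the branches above average several checkpoints before decoding, either directly (average_checkpoints) or via the running-average model. A simplified sketch of plain state-dict averaging over a list of checkpoint files, assuming each checkpoint stores its weights under the "model" key as icefall's save_checkpoint does; icefall's average_checkpoints handles integer buffers and other details, so treat this as illustrative only:

from typing import Dict, List

import torch


def average_state_dicts(filenames: List[str], device: torch.device) -> Dict[str, torch.Tensor]:
    """Element-wise mean of the 'model' state_dicts stored in the given checkpoints."""
    avg: Dict[str, torch.Tensor] = {}
    n = len(filenames)
    for f in filenames:
        state = torch.load(f, map_location=device)["model"]
        for k, v in state.items():
            # Accumulate in float64; note this also casts any integer buffers.
            if k in avg:
                avg[k] += v.to(torch.float64)
            else:
                avg[k] = v.to(torch.float64)
    return {k: (v / n).to(torch.float32) for k, v in avg.items()}


# Usage sketch (paths are placeholders):
# model.load_state_dict(average_state_dicts(["exp/epoch-18.pt", "exp/epoch-19.pt"], torch.device("cpu")))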
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/conformer_ctc2/export.py b/egs/librispeech/WSASR/conformer_ctc2/export.py new file mode 120000 index 000000000..5f484e391 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/export.py @@ -0,0 +1 @@ +../../ASR/conformer_ctc2/export.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py b/egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py new file mode 120000 index 000000000..c050ea637 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/label_smoothing.py @@ -0,0 +1 @@ +../../ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/optim.py b/egs/librispeech/WSASR/conformer_ctc2/optim.py new file mode 120000 index 000000000..db836b5e0 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/optim.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/scaling.py b/egs/librispeech/WSASR/conformer_ctc2/scaling.py new file mode 120000 index 000000000..bd0abfeee --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/scaling.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/librispeech/WSASR/conformer_ctc2/subsampling.py b/egs/librispeech/WSASR/conformer_ctc2/subsampling.py new file mode 100644 index 000000000..2ba802866 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/subsampling.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# 2022 Xiaomi Corporation (author: Quandong Wang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv2d, + ScaledLinear, +) + + +class Conv2dSubsampling(torch.nn.Module): + """Convolutional 2D subsampling (to 1/4 length). 
+ + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + """ + assert in_channels >= 7 + super().__init__() + + self.conv = torch.nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +class Conv2dSubsampling2(torch.nn.Module): + """Convolutional 2D subsampling (to 1/2 length). 
+ + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim) where + T' = (T - 1) // 2 - 2, which approximates T' == T // 2 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ) -> None: + assert in_channels >= 7 + super().__init__() + + self.conv = torch.nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * ((in_channels - 1) // 2 - 2), out_channels + ) + self.out_norm = BasicNorm(out_channels, learn_eps=False) + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.unsqueeze(1) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x = self.out_norm(x) + x = self.out_balancer(x) + return x diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py new file mode 100755 index 000000000..fe6c5af91 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/train.py @@ -0,0 +1,1115 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Quandong Wang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
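Editor's note: the two subsampling classes added above reduce the time axis to roughly T/4 (Conv2dSubsampling) and T/2 (Conv2dSubsampling2). A quick numerical check of the time-length formulas in their docstrings, using plain nn.Conv2d stacks with the same kernel sizes, strides, and padding (the scaled convolutions and balancers do not change the shapes):

import torch
import torch.nn as nn


def time_out_len(layers: nn.Sequential, T: int, idim: int = 80) -> int:
    x = torch.randn(1, 1, T, idim)  # (N, C, T, idim)
    return layers(x).shape[2]


# Same kernel/stride/padding pattern as Conv2dSubsampling (factor 4).
sub4 = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1),
    nn.Conv2d(8, 32, 3, stride=2),
    nn.Conv2d(32, 128, 3, stride=2),
)
# Same as Conv2dSubsampling2 (factor 2): the last conv has stride 1.
sub2 = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1),
    nn.Conv2d(8, 32, 3, stride=2),
    nn.Conv2d(32, 128, 3, stride=1),
)

for T in (100, 101, 357):
    assert time_out_len(sub4, T) == ((T - 1) // 2 - 1) // 2
    assert time_out_len(sub2, T) == (T - 1) // 2 - 2
print("output-length formulas match the conv stacks")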
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conformer_ctc2/train.py \ + --world-size 4 \ + --manifest-dir data/ssl \ + --train-manifest librispeech_cuts_train-clean-100_0.17_0.17_0.17.jsonl.gz \ + --exp-dir conformer_ctc2/exp \ + --lang-dir data/lang_bpe_200 \ + --otc-token "" \ + --allow-bypass-arc true \ + --allow-self-loop-arc true \ + --initial-bypass-weight -19 \ + --initial-self-loop-weight 3.75 \ + --bypass-weight-decay 0.975 \ + --self-loop-weight-decay 0.999 \ + --show-alignment true +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.decode import one_best_decoding +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.otc_graph_compiler import OtcTrainingGraphCompiler +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions_otc, + get_texts, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc2/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_200", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. 
This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.0, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=0, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=10, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--otc-token", + type=str, + default="_", + help="OTC token", + ) + + parser.add_argument( + "--allow-bypass-arc", + type=str2bool, + default=True, + help="""Whether to add bypass arc to training graph for substitution + and insertion errors (wrong or extra words in the transcript).""", + ) + + parser.add_argument( + "--allow-self-loop-arc", + type=str2bool, + default=True, + help="""Whether to self-loop bypass arc to training graph for deletion errors + (missing words in the transcript).""", + ) + + parser.add_argument( + "--initial-bypass-weight", + type=float, + default=0.0, + help="Initial weight associated with bypass arc", + ) + + parser.add_argument( + "--initial-self-loop-weight", + type=float, + default=0.0, + help="Initial weight associated with self-loop arc", + ) + + parser.add_argument( + "--bypass-weight-decay", + type=float, + default=1.0, + help="""Weight decay factor of bypass arc weight: + bypass_arc_weight = intial_bypass_weight * bypass_weight_decay ^ ith-epoch""", + ) + + parser.add_argument( + "--self-loop-weight-decay", + type=float, + default=1.0, + help="""Weight decay factor of self-loop arc weight: + self_loop_arc_weight = intial_self_loop_weight * self_loop_weight_decay ^ ith-epoch""", + ) + + parser.add_argument( + "--show-alignment", + type=str2bool, + default=True, + help="Whether to print OTC alignment during training", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - warm_step: The warm_step for Noam optimizer. 
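+
+    - alignment_interval: If --show-alignment is true, the best OTC alignment
+                          of the current batch is printed every
+                          `alignment_interval` training batches.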
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 1, + "reset_interval": 200, + "valid_interval": 800, # For the 100h subset, use 800 + "alignment_interval": 25, + # parameters for conformer + "feature_dim": 768, + "subsampling_factor": 2, + "encoder_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + # parameters for ctc loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + batch: dict, + graph_compiler: OtcTrainingGraphCompiler, + is_training: bool, + warmup: float = 2.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute OTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model( + feature, supervisions, warmup=warmup + ) + # Set the probability of OTC token as the average of non-blank tokens + # under the assumption that blank is the first and + # OTC token is the last token in tokens.txt + _, _, V = nnet_output.shape + + otc_token_log_prob = torch.logsumexp( + nnet_output[:, :, 1:], dim=-1, keepdim=True + ) - torch.log(torch.tensor([V - 1])).to(device) + + nnet_output = torch.cat([nnet_output, otc_token_log_prob], dim=-1) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts, utt_ids, verbatim_texts = encode_supervisions_otc( + supervisions, subsampling_factor=params.subsampling_factor + ) + + bypass_weight = graph_compiler.initial_bypass_weight * ( + graph_compiler.bypass_weight_decay ** (params.cur_epoch - 1) + ) + self_loop_weight = graph_compiler.initial_self_loop_weight * ( + graph_compiler.self_loop_weight_decay ** (params.cur_epoch - 1) + ) + + decoding_graph = graph_compiler.compile( + texts=texts, + allow_bypass_arc=params.allow_bypass_arc, + allow_self_loop_arc=params.allow_self_loop_arc, + bypass_weight=bypass_weight, + self_loop_weight=self_loop_weight, + ) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=3, + ) + + otc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + 
output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + assert params.att_rate == 0.0 + loss = otc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["otc_loss"] = otc_loss.detach().cpu().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + if params.show_alignment: + if params.batch_idx_train % params.alignment_interval == 0: + for index, utt_id in enumerate(utt_ids): + verbatim_text = verbatim_texts[index] + utt_id = utt_ids[index] + + lattice = k2.intersect_dense( + decoding_graph, + dense_fsa_vec, + params.beam_size, + ) + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=params.use_double_scores, + ) + hyp_ids = get_texts(best_path)[index] + hyp_text_list = [graph_compiler.token_table[i] for i in hyp_ids] + hyp_text = "".join(hyp_text_list).replace("▁", " ") + + logging.info(f"[utterance id]: {utt_id}") + logging.info(f"[verbatim text]: {verbatim_text}") + logging.info(f"[best alignment]: {hyp_text}") + logging.info(bypass_weight) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: OtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + graph_compiler: OtcTrainingGraphCompiler, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. 
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + # scaler.scale(loss).backward() + + try: + # loss.backward() + scaler.scale(loss).backward() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error(f"failing batch size:{batch_size} ") + raise + + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + if loss_info["otc_loss"] == float("inf"): + logging.error("Your loss contains inf, something goes wrong") + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. 
+ The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = OtcTrainingGraphCompiler( + params.lang_dir, + otc_token=params.otc_token, + device=device, + initial_bypass_weight=params.initial_bypass_weight, + initial_self_loop_weight=params.initial_self_loop_weight, + bypass_weight_decay=params.bypass_weight_decay, + self_loop_weight_decay=params.self_loop_weight_decay, + ) + + # remove OTC token as it is the average of all non-blank tokens + max_token_id = graph_compiler.get_max_token_id() - 1 + # add blank + num_classes = max_token_id + 1 + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.encoder_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_encoder_layers=params.num_encoder_layers, + num_decoder_layers=params.num_decoder_layers, + ) + + print(model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + graph_compiler=graph_compiler, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: OtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + warmup=0.0, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + assert "▁" not in args.otc_token + args.otc_token = f"▁{args.otc_token}" + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/conformer_ctc2/transformer.py b/egs/librispeech/WSASR/conformer_ctc2/transformer.py new file mode 100644 index 000000000..41e6cd357 --- /dev/null +++ b/egs/librispeech/WSASR/conformer_ctc2/transformer.py @@ -0,0 +1,1055 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Copyright 2022 Xiaomi Corp. (author: Quandong Wang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from attention import MultiheadAttention +from label_smoothing import LabelSmoothingLoss +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledEmbedding, + ScaledLinear, +) +from subsampling import Conv2dSubsampling +from torch.nn.utils.rnn import pad_sequence + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] + + +class Transformer(nn.Module): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + layer_dropout (float): layer-dropout rate. + """ + super().__init__() + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4 and subsampling_factor != 2: + raise NotImplementedError("Support only 'subsampling_factor=4 or 2'.") + + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). 
+ # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + layer_dropout=layer_dropout, + ) + + self.encoder = TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), ScaledLinear(d_model, num_classes, bias=True) + ) + + if num_decoder_layers > 0: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = ScaledEmbedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + ) + + self.decoder = TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + ) + + self.decoder_output_layer = ScaledLinear( + d_model, self.decoder_num_class, bias=True + ) + + self.decoder_criterion = LabelSmoothingLoss() + else: + self.decoder_criterion = None + + def forward( + self, + x: torch.Tensor, + supervision: Optional[Supervisions] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C). + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is (N, T). + It is None if `supervision` is None. + """ + + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision, warmup + ) + + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, + x: torch.Tensor, + supervisions: Optional[Supervisions] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. + Returns: + Return a tuple with two tensors: + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. 
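+
+        Example (shapes only; assumes this recipe's defaults of
+        num_features=768, d_model=512 and subsampling_factor=2):
+        an input ``x`` of shape (16, 1000, 768) produces an encoder output of
+        shape (497, 16, 512) and, when ``supervisions`` is given, a padding
+        mask of shape (16, 497), where 497 = (1000 - 1) // 2 - 2.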
+        """
+        x = self.encoder_embed(x)
+        x = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        # Pass the model's subsampling factor explicitly; otherwise
+        # `supervisions` would be bound to the `subsampling_factor`
+        # argument and the padding mask would always be None.
+        mask = encoder_padding_mask(
+            x.size(0),
+            subsampling_factor=self.subsampling_factor,
+            supervisions=supervisions,
+        )
+        mask = mask.to(x.device) if mask is not None else None
+        x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup)  # (T, N, C)
+
+        return x, mask
+
+    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          x:
+            The output tensor from the transformer encoder.
+            Its shape is (T, N, C)
+
+        Returns:
+          Return a tensor that can be used for CTC decoding.
+          Its shape is (N, T, C)
+        """
+        x = self.encoder_output_layer(x)
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
+        return x
+
+    @torch.jit.export
+    def decoder_forward(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder with shape (T, N, C)
+          memory_key_padding_mask:
+            The padding mask from the encoder.
+          token_ids:
+            A list of token-ID lists. Each sublist contains IDs for an
+            utterance. The IDs can be either phone IDs or word piece IDs.
+          sos_id:
+            sos token id
+          eos_id:
+            eos token id
+
+        Returns:
+          A scalar, the **sum** of label smoothing loss over utterances
+          in the batch without any normalization.
+        """
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device)
+        ys_out_pad = ys_out_pad.to(device)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+        )  # (T, N, C)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)
+
+        decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
+
+        return decoder_loss
+
+    @torch.jit.export
+    def decoder_nll(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[torch.Tensor],
+        sos_id: int,
+        eos_id: int,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder with shape (T, N, C)
+          memory_key_padding_mask:
+            The padding mask from the encoder.
+          token_ids:
+            A list of token-ID sequences (e.g., word piece IDs).
+            Each element represents an utterance.
+          sos_id:
+            The token ID for SOS.
+          eos_id:
+            The token ID for EOS.
+        Returns:
+          A 2-D tensor of shape (len(token_ids), max_token_length)
+          representing the cross entropy loss (i.e., negative log-likelihood).
+        """
+        # The common part between this function and decoder_forward could be
+        # extracted as a separate function.
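+        # The returned matrix has one row per utterance and one column per
+        # shifted target position; summing along a row gives that utterance's
+        # total negative log-likelihood (padding positions contribute 0 via
+        # ignore_index=-1). This path is only exercised when a decoder exists,
+        # i.e. when --num-decoder-layers > 0 (the default in this recipe is 0).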
+ if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + + self.layer_dropout = layer_dropout + + self.self_attn = MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). 
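+        # (The balancer below is an identity function in the forward pass; it
+        # only adjusts gradients in the backward pass so that the fraction of
+        # positive activations per channel stays within [min_positive,
+        # max_positive] and their magnitudes stay below max_abs. See
+        # scaling.py in this directory for details.)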
+ self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # src_att = self.self_attn(src, src, src, src_mask) + src_att = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = src + self.dropout(src_att) + + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + layer_dropout: float = 0.075, + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.layer_dropout = layer_dropout + self.self_attn = MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). 
+ self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + + + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + tgt_orig = tgt + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # tgt_att = self.self_attn(tgt, tgt, tgt, tgt_mask) + tgt_att = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = tgt + self.dropout(tgt_att) + + # src_att = self.src_attn(tgt, memory, memory, memory_mask) + src_att = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout(src_att) + + tgt = tgt + self.dropout(self.feed_forward(tgt)) + + tgt = self.norm_final(self.balancer(tgt)) + + if alpha != 1.0: + tgt = alpha * tgt + (1 - alpha) * tgt_orig + + return tgt + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + mask: (S, S). + src_key_padding_mask: (N, S). 
+            S is the source sequence length, T is the target sequence length,
+            N is the batch size, E is the feature number
+
+        """
+        output = src
+
+        for mod in self.layers:
+            output = mod(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    r"""TransformerDecoder is a stack of N decoder layers
+
+    Args:
+        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
+        num_layers: the number of sub-decoder-layers in the decoder (required).
+
+    Examples::
+        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
+        >>> memory = torch.rand(10, 32, 512)
+        >>> tgt = torch.rand(10, 32, 512)
+        >>> out = transformer_decoder(tgt, memory)
+    """
+
+    def __init__(self, decoder_layer: nn.Module, num_layers: int) -> None:
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(decoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        tgt_key_padding_mask: Optional[torch.Tensor] = None,
+        memory_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        r"""Pass the input through the decoder layers in turn.
+
+        Args:
+            tgt: the sequence to the decoder (required).
+            memory: the sequence from the last layer of the encoder (required).
+            tgt_mask: the mask for the tgt sequence (optional).
+            memory_mask: the mask for the memory sequence (optional).
+            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+            memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+        Shape:
+            tgt: (T, N, E).
+            memory: (S, N, E).
+            tgt_mask: (T, T).
+            memory_mask: (T, S).
+            tgt_key_padding_mask: (N, T).
+            memory_key_padding_mask: (N, S).
+            S is the source sequence length, T is the target sequence length,
+            N is the batch size, E is the feature number
+
+        """
+        output = tgt
+
+        for mod in self.layers:
+            output = mod(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                warmup=warmup,
+            )
+
+        return output
+
+
+class PositionalEncoding(nn.Module):
+    """This class implements the positional encoding
+    proposed in the following paper:
+
+    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+        PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
+        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
+
+    Note::
+
+      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+                               = exp(-1 * 2i / d_model * log(10000))
+                               = exp(2i * -(log(10000) / d_model))
+    """
+
+    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+        """
+        Args:
+          d_model:
+            Embedding dimension.
+          dropout:
+            Dropout probability to be applied to the output of this module.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = nn.Dropout(p=dropout)
+        # not doing: self.pe = None because of errors thrown by torchscript
+        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """Extend the time t in the positional encoding if required.
+
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, we recompute self.pe so that its shape
+        becomes (1, T, d_model). Otherwise, nothing is done.
+
+        Args:
+          x:
+            It is a tensor of shape (N, T, C).
+        Returns:
+          Return None.
+ """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +def encoder_padding_mask( + max_len: int, + subsampling_factor: Optional[int] = 4, + supervisions: Optional[Supervisions] = None, +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4 or 2. We should remove that + assumption later. + + Args: + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. 
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), + True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + if subsampling_factor == 4: + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + elif subsampling_factor == 2: + lengths = [(i - 1) // 2 - 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. 
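+
+    Example (illustrative):
+
+      >>> add_eos([[1, 2], [3]], eos_id=500)
+      [[1, 2, 500], [3, 500]]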
+ """ + return [utt + [eos_id] for utt in token_ids] + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/librispeech/WSASR/figures/del.png b/egs/librispeech/WSASR/figures/del.png new file mode 100644 index 000000000..38973980b Binary files /dev/null and b/egs/librispeech/WSASR/figures/del.png differ diff --git a/egs/librispeech/WSASR/figures/ins.png b/egs/librispeech/WSASR/figures/ins.png new file mode 100644 index 000000000..2d0e807a9 Binary files /dev/null and b/egs/librispeech/WSASR/figures/ins.png differ diff --git a/egs/librispeech/WSASR/figures/otc_emission.drawio.png b/egs/librispeech/WSASR/figures/otc_emission.drawio.png new file mode 100644 index 000000000..6cea5531d Binary files /dev/null and b/egs/librispeech/WSASR/figures/otc_emission.drawio.png differ diff --git a/egs/librispeech/WSASR/figures/otc_g.png b/egs/librispeech/WSASR/figures/otc_g.png new file mode 100644 index 000000000..ebad49180 Binary files /dev/null and b/egs/librispeech/WSASR/figures/otc_g.png differ diff --git a/egs/librispeech/WSASR/figures/otc_training_graph.drawio.png b/egs/librispeech/WSASR/figures/otc_training_graph.drawio.png new file mode 100644 index 000000000..8978158d8 Binary files /dev/null and b/egs/librispeech/WSASR/figures/otc_training_graph.drawio.png differ diff --git a/egs/librispeech/WSASR/figures/sub.png b/egs/librispeech/WSASR/figures/sub.png new file mode 100644 index 000000000..5674e9feb Binary files /dev/null and b/egs/librispeech/WSASR/figures/sub.png differ diff --git a/egs/librispeech/WSASR/local/compile_hlg.py b/egs/librispeech/WSASR/local/compile_hlg.py new file mode 100755 index 000000000..63791f4cc --- /dev/null +++ b/egs/librispeech/WSASR/local/compile_hlg.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input lang_dir and generates HLG from + + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_n_gram.fst.txt + +The generated HLG is saved in $lang_dir/HLG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) + parser.add_argument( + "--lm-dir", + type=str, + help="""LM directory. + """, + ) + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. 
+ """, + ) + + return parser.parse_args() + + +def compile_HLG(lm_dir: str, lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + lm: + The language stem base name. + + Return: + An FSA representing HLG. + """ + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") + H = k2.ctc_topo(max_token_id) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path(f"{lm_dir}/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"{lm_dir}/{lm}.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info(f"Loading {lm}.fst.txt") + with open(f"{lm_dir}/{lm}.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), f"{lm_dir}/{lm}.pt") + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + LG.labels[LG.labels >= first_token_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + # CAUTION: The name of the inner_labels is fixed + # to `tokens`. If you want to change it, please + # also change other places in icefall that are using + # it. + HLG = k2.compose(H, LG, inner_labels="tokens") + + logging.info("Connecting LG") + HLG = k2.connect(HLG) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + logging.info(f"HLG.shape: {HLG.shape}") + + return HLG + + +def main(): + args = get_args() + lm_dir = Path(args.lm_dir) + lang_dir = Path(args.lang_dir) + + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + HLG = compile_HLG(lm_dir, lang_dir, args.lm) + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/WSASR/local/compute_fbank_librispeech.py b/egs/librispeech/WSASR/local/compute_fbank_librispeech.py new file mode 100755 index 000000000..a387d54c9 --- /dev/null +++ b/egs/librispeech/WSASR/local/compute_fbank_librispeech.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LibriSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + return parser.parse_args() + + +def compute_fbank_librispeech( + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + ) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = "librispeech" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. 
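+        # Process one LibriSpeech split at a time: skip splits whose cut
+        # manifest already exists, optionally filter/speed-perturb the train
+        # split, then compute fbank features and store them on disk.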
+ for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + + if "train" in partition: + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_librispeech( + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/librispeech/WSASR/local/compute_ssl_librispeech.py b/egs/librispeech/WSASR/local/compute_ssl_librispeech.py new file mode 100755 index 000000000..f405c468c --- /dev/null +++ b/egs/librispeech/WSASR/local/compute_ssl_librispeech.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LibriSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
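+# Note: the S3PRL-based extractor used below requires the s3prl toolkit
+# (installable, e.g., with `pip install s3prl`) in addition to lhotse.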
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_ssl_librispeech(): + src_dir = Path("data/manifests") + output_dir = Path("data/ssl") + num_jobs = 1 + + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + ) + prefix = "librispeech" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda")) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + storage_type=NumpyFilesWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_ssl_librispeech() diff --git a/egs/librispeech/WSASR/local/filter_cuts.py b/egs/librispeech/WSASR/local/filter_cuts.py new file mode 100644 index 000000000..fbcc9e24a --- /dev/null +++ b/egs/librispeech/WSASR/local/filter_cuts.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script removes short and long utterances from a cutset. + +Caution: + You may need to tune the thresholds for your own dataset. 
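+Concretely, cuts shorter than 1 second or longer than 20 seconds are removed,
+as are cuts whose number of frames after subsampling is smaller than the
+number of BPE tokens in their transcript.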
+ +Usage example: + + python3 ./local/filter_cuts.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ + --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=Path, + help="Path to the bpe.model", + ) + + parser.add_argument( + "--in-cuts", + type=Path, + help="Path to the input cutset", + ) + + parser.add_argument( + "--out-cuts", + type=Path, + help="Path to the output cutset", + ) + + return parser.parse_args() + + +def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): + total = 0 # number of total utterances before removal + removed = 0 # number of removed utterances + + def remove_short_and_long_utterances(c: Cut): + """Return False to exclude the input cut""" + nonlocal removed, total + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ./display_manifest_statistics.py + # + # You should use ./display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + total += 1 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + removed += 1 + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./pruned_transducer_stateless2/conformer.py, the + # conv module uses the following expression + # for subsampling + if c.num_frames is None: + num_frames = c.duration * 100 # approximate + else: + num_frames = c.num_frames + + T = ((num_frames - 1) // 2 - 1) // 2 + # Note: for ./lstm_transducer_stateless/lstm.py, the formula is + # T = ((num_frames - 3) // 2 - 1) // 2 + + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + removed += 1 + return False + + return True + + # We use to_eager() here so that we can print out the value of total + # and removed below. + ans = cut_set.filter(remove_short_and_long_utterances).to_eager() + ratio = removed / total * 100 + logging.info( + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." 
+ ) + return ans + + +def main(): + args = get_args() + logging.info(vars(args)) + + if args.out_cuts.is_file(): + logging.info(f"{args.out_cuts} already exists - skipping") + return + + assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" + assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + cut_set = load_manifest_lazy(args.in_cuts) + assert isinstance(cut_set, CutSet) + + cut_set = filter_cuts(cut_set, sp) + logging.info(f"Saving to {args.out_cuts}") + args.out_cuts.parent.mkdir(parents=True, exist_ok=True) + cut_set.to_file(args.out_cuts) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/WSASR/local/get_words_from_lexicon.py b/egs/librispeech/WSASR/local/get_words_from_lexicon.py new file mode 100755 index 000000000..0cc740b36 --- /dev/null +++ b/egs/librispeech/WSASR/local/get_words_from_lexicon.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +import argparse +from pathlib import Path + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--otc-token", + type=str, + help="OTC token to be added to words.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + otc_token = args.otc_token + + lexicon = read_lexicon(lang_dir / "lexicon.txt") + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + words = [""] + sorted_ans + [otc_token] + ["#0", "", ""] + + words_file = lang_dir / "words.txt" + with open(words_file, "w") as wf: + for i, word in enumerate(words): + wf.write(f"{word} {i}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/make_error_cutset.py b/egs/librispeech/WSASR/local/make_error_cutset.py new file mode 100755 index 000000000..8463a380e --- /dev/null +++ b/egs/librispeech/WSASR/local/make_error_cutset.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 + +# Copyright 2023 Johns Hopkins University (author: Dongji Gao) + +import argparse +import random +from pathlib import Path +from typing import List + +from lhotse import CutSet, load_manifest +from lhotse.cut.base import Cut + +from icefall.utils import str2bool + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--input-cutset", + type=str, + help="Supervision manifest that contains verbatim transcript", + ) + + parser.add_argument( + "--words-file", + type=str, + help="words.txt file", + ) + + parser.add_argument( + "--otc-token", + type=str, + help="OTC token in words.txt", + ) + + parser.add_argument( + "--sub-error-rate", + type=float, + default=0.0, + help="Substitution error rate", + ) + + parser.add_argument( + "--ins-error-rate", + type=float, + default=0.0, + help="Insertion error rate", + ) + + parser.add_argument( + "--del-error-rate", + type=float, + default=0.0, + help="Deletion error rate", + ) + + parser.add_argument( + "--output-cutset", + type=str, + default="", + help="Supervision manifest that contains modified non-verbatim transcript", + ) + + 
parser.add_argument("--verbose", type=str2bool, help="show details of errors") + return parser.parse_args() + + +def check_args(args): + total_error_rate = args.sub_error_rate + args.ins_error_rate + args.del_error_rate + assert args.sub_error_rate >= 0 and args.sub_error_rate <= 1.0 + assert args.ins_error_rate >= 0 and args.sub_error_rate <= 1.0 + assert args.del_error_rate >= 0 and args.sub_error_rate <= 1.0 + assert total_error_rate <= 1.0 + + +def get_word_list(token_path: str) -> List: + word_list = [] + with open(Path(token_path), "r") as tp: + for line in tp.readlines(): + token = line.split()[0] + assert token not in word_list + word_list.append(token) + return word_list + + +def modify_cut_text( + cut: Cut, + words_list: List, + non_words: List, + sub_ratio: float = 0.0, + ins_ratio: float = 0.0, + del_ratio: float = 0.0, +): + text = cut.supervisions[0].text + text_list = text.split() + + # We save the modified information of the original verbatim text for debugging + marked_verbatim_text_list = [] + modified_text_list = [] + + del_index_set = set() + sub_index_set = set() + ins_index_set = set() + + # We follow the order: deletion -> substitution -> insertion + for token in text_list: + marked_token = token + modified_token = token + + prob = random.random() + + if prob <= del_ratio: + marked_token = f"-{token}-" + modified_token = "" + elif prob <= del_ratio + sub_ratio + ins_ratio: + if prob <= del_ratio + sub_ratio: + marked_token = f"[{token}]" + else: + marked_verbatim_text_list.append(marked_token) + modified_text_list.append(modified_token) + marked_token = "[]" + + # get new_token + while ( + modified_token == token + or modified_token in non_words + or modified_token.startswith("#") + ): + modified_token = random.choice(words_list) + + marked_verbatim_text_list.append(marked_token) + modified_text_list.append(modified_token) + + marked_text = " ".join(marked_verbatim_text_list) + modified_text = " ".join(modified_text_list) + + if not hasattr(cut.supervisions[0], "verbatim_text"): + cut.supervisions[0].verbatim_text = marked_text + cut.supervisions[0].text = modified_text + + return cut + + +def main(): + args = get_args() + check_args(args) + + otc_token = args.otc_token + non_words = set(("sil", "", "")) + non_words.add(otc_token) + + words_list = get_word_list(args.words_file) + cutset = load_manifest(Path(args.input_cutset)) + + cuts = [] + + for cut in cutset: + modified_cut = modify_cut_text( + cut=cut, + words_list=words_list, + non_words=non_words, + sub_ratio=args.sub_error_rate, + ins_ratio=args.ins_error_rate, + del_ratio=args.del_error_rate, + ) + cuts.append(modified_cut) + + output_cutset = CutSet.from_cuts(cuts) + output_cutset.to_file(args.output_cutset) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/prepare_lang.py b/egs/librispeech/WSASR/local/prepare_lang.py new file mode 100755 index 000000000..d913756a1 --- /dev/null +++ b/egs/librispeech/WSASR/local/prepare_lang.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" +consisting of words and tokens (i.e., phones) and does the following: + +1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt + +2. Generate tokens.txt, the token table mapping a token to a unique integer. + +3. Generate words.txt, the word table mapping a word to a unique integer. + +4. Generate L.pt, in k2 format. It can be loaded by + + d = torch.load("L.pt") + lexicon = k2.Fsa.from_dict(d) + +5. Generate L_disambig.pt, in k2 format. +""" +import argparse +import math +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import k2 +import torch + +from icefall.lexicon import read_lexicon, write_lexicon +from icefall.utils import str2bool + +Lexicon = List[Tuple[str, List[str]]] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + """, + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique tokens. + """ + ans = set() + for _, tokens in lexicon: + ans.update(tokens) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def get_words(lexicon: Lexicon) -> List[str]: + """Get words from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique words. + """ + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + It is returned by :func:`read_lexicon`. + Returns: + Return a tuple with two elements: + + - The output lexicon with disambiguation symbols + - The ID of the max disambiguation symbol that appears + in the lexicon + """ + + # (1) Work out the count of each token-sequence in the + # lexicon. 
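+    # (For example, if "read" and "red" share the token sequence "r eh d",
+    # they come out as "r eh d #1" and "r eh d #2" respectively.)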
+ count = defaultdict(int) + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). + issubseq = defaultdict(int) + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + ans = [] + + # We start with #1 since #0 has its own purpose + first_allowed_disambig = 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_symbol_of = defaultdict(int) + + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) + continue + + cur_disambig = last_used_disambig_symbol_of[tokenseq] + if cur_disambig == 0: + cur_disambig = first_allowed_disambig + else: + cur_disambig += 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) + return ans, max_disambig + + +def generate_id_map(symbols: List[str]) -> Dict[str, int]: + """Generate ID maps, i.e., map a symbol to a unique ID. + + Args: + symbols: + A list of unique symbols. + Returns: + A dict containing the mapping between symbols and IDs. + """ + return {sym: i for i, sym in enumerate(symbols)} + + +def add_self_loops( + arcs: List[List[Any]], disambig_token: int, disambig_word: int +) -> List[List[Any]]: + """Adds self-loops to states of an FST to propagate disambiguation symbols + through it. They are added on each state with non-epsilon output symbols + on at least one arc out of the state. + + See also fstaddselfloops.pl from Kaldi. One difference is that + Kaldi uses OpenFst style FSTs and it has multiple final states. + This function uses k2 style FSTs and it does not need to add self-loops + to the final state. + + The input label of a self-loop is `disambig_token`, while the output + label is `disambig_word`. + + Args: + arcs: + A list-of-list. The sublist contains + `[src_state, dest_state, label, aux_label, score]` + disambig_token: + It is the token ID of the symbol `#0`. + disambig_word: + It is the word ID of the symbol `#0`. + + Return: + Return new `arcs` containing self-loops. + """ + states_needs_self_loops = set() + for arc in arcs: + src, dst, ilabel, olabel, score = arc + if olabel != 0: + states_needs_self_loops.add(src) + + ans = [] + for s in states_needs_self_loops: + ans.append([s, s, disambig_token, disambig_word, 0]) + + return arcs + ans + + +def lexicon_to_fst( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + sil_token: str = "SIL", + sil_prob: float = 0.5, + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format) with optional silence at + the beginning and end of each word. + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + sil_token: + The silence token. + sil_prob: + The probability for adding a silence at the beginning and end + of the word. 
+ need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + assert sil_prob > 0.0 and sil_prob < 1.0 + # CAUTION: we use score, i.e, negative cost. + sil_score = math.log(sil_prob) + no_sil_score = math.log(1.0 - sil_prob) + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + arcs = [] + + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + sil_token = token2id[sil_token] + + arcs.append([start_state, loop_state, eps, eps, no_sil_score]) + arcs.append([start_state, sil_state, eps, eps, sil_score]) + arcs.append([sil_state, loop_state, sil_token, eps, 0]) + + for word, tokens in lexicon: + assert len(tokens) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + tokens = [token2id[i] for i in tokens] + + for i in range(len(tokens) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, tokens[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last token of this word + # It has two out-going arcs, one to the loop state, + # the other one to the sil_state. + i = len(tokens) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) + arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + lexicon_filename = lang_dir / "lexicon.txt" + sil_token = "SIL" + sil_prob = 0.5 + + lexicon = read_lexicon(lexicon_filename) + tokens = get_tokens(lexicon) + words = get_words(lexicon) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in tokens + tokens.append(f"#{i}") + + assert "" not in tokens + tokens = [""] + tokens + + assert "" not in words + assert "#0" not in words + assert "" not in words + assert "" not in words + + words = [""] + words + ["#0", "", ""] + + token2id = generate_id_map(tokens) + word2id = generate_id_map(words) + + write_mapping(lang_dir / "tokens.txt", token2id) + write_mapping(lang_dir / "words.txt", word2id) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst( + lexicon, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + ) + + L_disambig = lexicon_to_fst( + lexicon_disambig, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = 
k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py b/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py new file mode 100755 index 000000000..415bdff6f --- /dev/null +++ b/egs/librispeech/WSASR/local/prepare_otc_lang_bpe.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.utils import str2bool + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. 
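+        The word label is emitted on the arc of the first word piece of each
+        word; the remaining pieces carry epsilon output labels.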
+ """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_otc_lexicon( + model_file: str, + words: List[str], + oov: str, + otc_token: str, +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + otc_token: + The OTC token in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) + + # Now convert word piece IDs back to word piece strings. + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + # Add OTC token to the last. + lexicon.append((otc_token, [f"▁{otc_token}"])) + otc_token_index = len(token2id) + token2id[f"▁{otc_token}"] = otc_token_index + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--otc-token", + type=str, + default="", + help="The OTC token in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. 
+ """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bpe.model" + otc_token = args.otc_token + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = [ + "", + "!SIL", + "", + args.oov, + otc_token, + "#0", + "", + "", + ] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_otc_lexicon( + model_file, words, args.oov, otc_token + ) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/train_bpe_model.py b/egs/librispeech/WSASR/local/train_bpe_model.py new file mode 100755 index 000000000..43142aee4 --- /dev/null +++ b/egs/librispeech/WSASR/local/train_bpe_model.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. 
+ """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/validate_bpe_lexicon.py b/egs/librispeech/WSASR/local/validate_bpe_lexicon.py new file mode 100755 index 000000000..16a489c11 --- /dev/null +++ b/egs/librispeech/WSASR/local/validate_bpe_lexicon.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks that there are no OOV tokens in the BPE-based lexicon. 
+ +Usage example: + + python3 ./local/validate_bpe_lexicon.py \ + --lexicon /path/to/lexicon.txt \ + --bpe-model /path/to/bpe.model +""" + +import argparse +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm + +from icefall.lexicon import read_lexicon + +# Map word to word pieces +Lexicon = List[Tuple[str, List[str]]] + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--lexicon", + required=True, + type=Path, + help="Path to lexicon.txt", + ) + + parser.add_argument( + "--bpe-model", + required=True, + type=Path, + help="Path to bpe.model", + ) + + parser.add_argument( + "--otc-token", + required=True, + type=str, + help="OTC token", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + assert args.lexicon.is_file(), args.lexicon + assert args.bpe_model.is_file(), args.bpe_model + + lexicon = read_lexicon(args.lexicon) + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + word_pieces = set(sp.id_to_piece(list(range(sp.vocab_size())))) + word_pieces.add(f"▁{args.otc_token}") + for word, pieces in lexicon: + for p in pieces: + if p not in word_pieces: + raise ValueError(f"The word {word} contains an OOV token {p}") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/WSASR/local/validate_manifest.py b/egs/librispeech/WSASR/local/validate_manifest.py new file mode 100755 index 000000000..f620b91ea --- /dev/null +++ b/egs/librispeech/WSASR/local/validate_manifest.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut +- Supervision time bounds are within cut time bounds + +We will add more checks later if needed. 
+ +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def validate_one_supervision_per_cut(c: Cut): + if len(c.supervisions) != 1: + raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") + + +def validate_supervision_and_cut_time_bounds(c: Cut): + s = c.supervisions[0] + if s.start < c.start: + raise ValueError( + f"{c.id}: Supervision start time {s.start} is less " + f"than cut start time {c.start}" + ) + + if s.end > c.end: + raise ValueError( + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" + ) + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet) + + for c in cut_set: + validate_one_supervision_per_cut(c) + validate_supervision_and_cut_time_bounds(c) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/WSASR/prepare.sh b/egs/librispeech/WSASR/prepare.sh new file mode 100755 index 000000000..f6a922fde --- /dev/null +++ b/egs/librispeech/WSASR/prepare.sh @@ -0,0 +1,233 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/LibriSpeech +# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. +# You can download them from https://www.openslr.org/12 +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz +# +otc_token="" +feature_type="ssl" + +dl_dir=$PWD/download +manifests_dir="data/manifests" +feature_dir="data/${feature_type}" +lang_dir="data/lang" +lm_dir="data/lm" + +perturb_speed=false + +# ssl or fbank + +. ./cmd.sh +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + 200 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: ${dl_dir}" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download LM" + mkdir -p ${dl_dir}/lm + if [ ! 
-e ${dl_dir}/lm/.done ]; then
+    ./local/download_lm.py --out-dir=${dl_dir}/lm
+    touch ${dl_dir}/lm/.done
+  fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/LibriSpeech,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
+  #
+  if [ ! -d $dl_dir/LibriSpeech/train-clean-100 ]; then
+    lhotse download librispeech --full ${dl_dir}
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare LibriSpeech manifest"
+  # We assume that you have downloaded the LibriSpeech corpus
+  # to $dl_dir/LibriSpeech
+  mkdir -p data/manifests
+  if [ ! -e data/manifests/.librispeech.done ]; then
+    lhotse prepare librispeech -j ${nj} \
+      -p dev-clean \
+      -p dev-other \
+      -p test-clean \
+      -p test-other \
+      -p train-clean-100 "${dl_dir}/LibriSpeech" "${manifests_dir}"
+    touch data/manifests/.librispeech.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Compute ${feature_type} feature for librispeech (train-clean-100)"
+  mkdir -p "${feature_dir}"
+  if [ ! -e "${feature_dir}/.librispeech.done" ]; then
+    if [ "${feature_type}" = ssl ]; then
+      ./local/compute_ssl_librispeech.py
+    elif [ "${feature_type}" = fbank ]; then
+      ./local/compute_fbank_librispeech.py --perturb-speed ${perturb_speed}
+    else
+      log "Error: not supported --feature-type '${feature_type}'"
+      exit 2
+    fi
+
+    touch "${feature_dir}/.librispeech.done"
+  fi
+
+  if [ ! -e "${feature_dir}/.librispeech-validated.done" ]; then
+    log "Validating ${feature_dir} for LibriSpeech"
+    parts=(
+      train-clean-100
+      test-clean
+      test-other
+      dev-clean
+      dev-other
+    )
+    for part in ${parts[@]}; do
+      python3 ./local/validate_manifest.py \
+        "${feature_dir}/librispeech_cuts_${part}.jsonl.gz"
+    done
+    touch "${feature_dir}/.librispeech-validated.done"
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Prepare words.txt"
+  mkdir -p ${lang_dir}
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/lm/librispeech-lexicon.txt |
+    sort | uniq > ${lang_dir}/lexicon.txt
+
+  local/get_words_from_lexicon.py \
+    --lang-dir ${lang_dir} \
+    --otc-token ${otc_token}
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    bpe_lang_dir="data/lang_bpe_${vocab_size}"
+    mkdir -p "${bpe_lang_dir}"
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp "${lang_dir}/words.txt" "${bpe_lang_dir}"
+
+    if [ ! -f "${bpe_lang_dir}/transcript_words.txt" ]; then
+      log "Generate data for BPE training"
+      files=$(
+        find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
+        find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
+        find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
+      )
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2-
+      done > "${bpe_lang_dir}/transcript_words.txt"
+    fi
+
+    if [ ! -f ${bpe_lang_dir}/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir ${bpe_lang_dir} \
+        --vocab-size ${vocab_size} \
+        --transcript ${bpe_lang_dir}/transcript_words.txt
+    fi
+
+    if [ !
-f ${bpe_lang_dir}/L_disambig.pt ]; then + ./local/prepare_otc_lang_bpe.py \ + --lang-dir "${bpe_lang_dir}" \ + --otc-token "${otc_token}" + + log "Validating ${bpe_lang_dir}/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon ${bpe_lang_dir}/lexicon.txt \ + --bpe-model ${bpe_lang_dir}/bpe.model \ + --otc-token "${otc_token}" + fi + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare G" + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p "${lm_dir}" + if [ ! -f ${lm_dir}/G_3_gram.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="${lang_dir}/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + ${dl_dir}/lm/3-gram.pruned.1e-7.arpa > ${lm_dir}/G_3_gram.fst.txt + fi + + if [ ! -f ${lm_dir}/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + python3 -m kaldilm \ + --read-symbol-table="${lang_dir}/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + ${dl_dir}/lm/4-gram.arpa > ${lm_dir}/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Compile HLG" + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + bpe_lang_dir="data/lang_bpe_${vocab_size}" + echo "LM DIR: ${lm_dir}" + ./local/compile_hlg.py \ + --lm-dir "${lm_dir}" \ + --lang-dir "${bpe_lang_dir}" + done +fi diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py index a68702776..48468cfbd 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py @@ -746,7 +746,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -966,7 +965,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1019,7 +1018,6 @@ def run(rank, world_size, args): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) @@ -1118,7 +1116,6 @@ def scan_pessimistic_batches_for_oom( # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _, _ = compute_loss( params=params, model=model, diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index 5d0fe66a4..c09b9c1de 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -49,7 +49,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Prepare THCHS-30" if [ ! -d $dl_dir/thchs30 ]; then log "Downloading THCHS-30" - lhotse download thchs30 $dl_dir/thchs30 + lhotse download thchs-30 $dl_dir/thchs30 fi if [ ! 
-f data/manifests/.thchs30.done ]; then
diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py
index 4f2d728be..c1bbd2ee8 100755
--- a/egs/multi_zh-hans/ASR/zipformer/train.py
+++ b/egs/multi_zh-hans/ASR/zipformer/train.py
@@ -1164,7 +1164,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/swbd/ASR/.gitignore b/egs/swbd/ASR/.gitignore
new file mode 100644
index 000000000..11d674922
--- /dev/null
+++ b/egs/swbd/ASR/.gitignore
@@ -0,0 +1,2 @@
+switchboard_word_alignments.tar.gz
+./swb_ms98_transcriptions/
diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md
new file mode 100644
index 000000000..13b27815a
--- /dev/null
+++ b/egs/swbd/ASR/README.md
@@ -0,0 +1,25 @@
+# Switchboard
+
+The Switchboard-1 Telephone Speech Corpus (LDC97S62) consists of approximately 260 hours of speech and was originally collected by Texas Instruments in 1990-1, under DARPA sponsorship. The first release of the corpus was published by NIST and distributed by the LDC in 1992-3. Since that release, a number of corrections have been made to the data files as presented on the original CD-ROM set and all copies of the first pressing have been distributed.
+
+Switchboard is a collection of about 2,400 two-sided telephone conversations among 543 speakers (302 male, 241 female) from all areas of the United States. A computer-driven robot operator system handled the calls, giving the caller appropriate recorded prompts, selecting and dialing another person (the callee) to take part in a conversation, introducing a topic for discussion and recording the speech from the two subjects into separate channels until the conversation was finished. About 70 topics were provided, of which about 50 were used frequently. Selection of topics and callees was constrained so that: (1) no two speakers would converse together more than once and (2) no one spoke more than once on a given topic.
+
+(The above introduction is from the [LDC Switchboard-1 Release 2 webpage](https://catalog.ldc.upenn.edu/LDC97S62).)
+
+## Performance Record
+
+|                 | eval2000 | rt03  |
+|-----------------|----------|-------|
+| `conformer_ctc` | 33.37    | 35.06 |
+
+See [RESULTS](/egs/swbd/ASR/RESULTS.md) for details.
+
+## Credit
+
+The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ctc` recipe in icefall.
+
+Many of the data processing scripts come from first-generation Kaldi and from the ESPnet project; I adapted them to work with Lhotse and icefall.
+
+Some of the text normalization scripts are taken from stale pull requests by [Piotr Żelasko](https://github.com/pzelasko) and [Nagendra Goel](https://github.com/ngoel17).
+
+`sclite_scoring.py` comes from the GigaSpeech recipe and is used for post-processing and GLM-style scoring, which is admittedly not an elegant solution.
diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md new file mode 100644 index 000000000..f3a22c444 --- /dev/null +++ b/egs/swbd/ASR/RESULTS.md @@ -0,0 +1,113 @@ +## Results +### Switchboard BPE training results (Conformer-CTC) + +#### 2023-09-04 + +The best WER, as of 2023-09-04, for the Switchboard is below + +Results using attention decoder are given as: + +| | eval2000-swbd | eval2000-callhome | eval2000-avg | +|--------------------------------|-----------------|---------------------|--------------| +| `conformer_ctc` | 9.48 | 17.73 | 13.67 | + +Decoding results and models can be found here: +https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 +#### 2023-06-27 + +The best WER, as of 2023-06-27, for the Switchboard is below + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 30.80 | 32.29 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.1 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.9 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ + --num-epochs 100 +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 99 \ + --avg 10 \ + --max-duration 50 +``` + +#### 2023-06-26 + +The best WER, as of 2023-06-26, for the Switchboard is below + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 33.37 | 35.06 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.3 | 2.5 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.7 | 1.3 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 55 \ + --avg 1 \ + --max-duration 50 +``` + +For your reference, the nbest oracle WERs are: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 25.64 | 26.84 | diff --git a/egs/swbd/ASR/conformer_ctc/__init__.py b/egs/swbd/ASR/conformer_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py new file mode 100644 index 000000000..59d73c660 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1,416 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class SwitchBoardAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train dataloader, + but there can be multiple test dataloaders (e.g. SwitchBoard rt03 + and eval2000). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. 
This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. 
+ num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=50, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + buffer_size=50000, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_all_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(last=166844) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(first=300) + + @lru_cache() + def test_eval2000_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get eval2000 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "eval2000" / "eval2000_cuts_all.jsonl.gz" + ) + + @lru_cache() + def test_rt03_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get rt03 cuts") + return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_rt03.jsonl.gz") diff --git a/egs/swbd/ASR/conformer_ctc/conformer.py b/egs/swbd/ASR/conformer_ctc/conformer.py new file mode 120000 index 000000000..d1f4209d7 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/conformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py new file mode 100755 index 000000000..2bbade374 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -0,0 +1,853 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer + +from sclite_scoring import asr_text_post_processing + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_rnn_lm, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=98, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. 
+ It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. 
+ + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. 
+ HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + if test_set_name == "test-eval2000": + subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"} + elif test_set_name == "test-rt03": + subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"} + else: + raise NotImplementedError(f"No implementation for testset {test_set_name}") + for subset, prefix in subsets.items(): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt" + results = post_processing(results) + results = ( + sorted(list(filter(lambda x: x[0].startswith(prefix), results))) + if subset != "avg" + else sorted(results) + ) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"{test_set_name}-{subset}-{key}", + results, + enable_log=enable_log, + sclite_mode=True, + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{subset}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}-{}, WER of different settings are:\n".format( + test_set_name, subset + ) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. 
+ G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + + # we need cut ids to display recognition results. + args.return_cuts = True + switchboard = SwitchBoardAsrDataModule(args) + + test_eval2000_cuts = switchboard.test_eval2000_cuts().trim_to_supervisions( + keep_all_channels=True + ) + # test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( + # keep_all_channels=True + # ) + + test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) + # test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) + + # test_sets = ["test-eval2000", "test-rt03"] + # test_dl = [test_eval2000_dl, test_rt03_dl] + test_sets = ["test-eval2000"] + test_dl = [test_eval2000_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py new file mode 100755 index 000000000..1bb6277ad --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=98, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + } + ) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded 
+ # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/conformer_ctc/label_smoothing.py b/egs/swbd/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py new file mode 120000 index 000000000..526bc9678 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/pretrained.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py new file mode 100755 index 000000000..0383c4d71 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright 2021 Jiayu Du +# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", + "MHM", + "HUM", + "AW", + "OH", + "HMM", + "UMM", +] +unk_tags = ["", ""] +switchboard_garbage_utterance_tags = [ + "[LAUGHTER]", + "[NOISE]", + "[VOCALIZED-NOISE]", + "[SILENCE]", +] +non_scoring_words = ( + conversational_filler + unk_tags + switchboard_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. 
remove non-scoring words from evaluation + remaining_words = [] + text_split = text.split() + word_to_skip = 0 + for idx, word in enumerate(text_split): + if word_to_skip > 0: + word_to_skip -= 1 + continue + if word in non_scoring_words: + continue + elif word == "CANCELLED": + remaining_words.append("CANCELED") + continue + elif word == "AIRFLOW": + remaining_words.append("AIR") + remaining_words.append("FLOW") + continue + elif word == "PHD": + remaining_words.append("P") + remaining_words.append("H") + remaining_words.append("D") + continue + elif word == "UCLA": + remaining_words.append("U") + remaining_words.append("C") + remaining_words.append("L") + remaining_words.append("A") + continue + elif word == "ONTO": + remaining_words.append("ON") + remaining_words.append("TO") + continue + elif word == "DAY": + try: + if text_split[idx + 1] == "CARE": + remaining_words.append("DAYCARE") + word_to_skip = 1 + except: + remaining_words.append(word) + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" + ) + parser.add_argument( + "ref", + type=str, + help="sclite's standard transcription(trn) reference file", + ) + parser.add_argument( + "hyp", + type=str, + help="sclite's standard transcription(trn) hypothesis file", + ) + parser.add_argument( + "work_dir", + type=str, + help="working dir", + ) + args = parser.parse_args() + + if not os.path.isdir(args.work_dir): + os.mkdir(args.work_dir) + + REF = os.path.join(args.work_dir, "REF") + HYP = os.path.join(args.work_dir, "HYP") + RESULT = os.path.join(args.work_dir, "RESULT") + + for io in [(args.ref, REF), (args.hyp, HYP)]: + with open(io[0], "r", encoding="utf8") as fi: + with open(io[1], "w+", encoding="utf8") as fo: + for line in fi: + line = line.strip() + if line: + cols = line.split() + text = asr_text_post_processing(" ".join(cols[0:-1])) + uttid_field = cols[-1] + print(f"{text} {uttid_field}", file=fo) + + # GigaSpeech's uttid comforms to swb + os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/swbd/ASR/conformer_ctc/subsampling.py b/egs/swbd/ASR/conformer_ctc/subsampling.py new file mode 120000 index 000000000..16354dc73 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py new file mode 100755 index 000000000..5d4438fd1 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from distutils.version import LooseVersion + +import torch +from label_smoothing import LabelSmoothingLoss + +torch_ver = LooseVersion(torch.__version__) + + +def test_with_torch_label_smoothing_loss(): + if torch_ver < LooseVersion("1.10.0"): + print(f"Current torch version: {torch_ver}") + print("Please use torch >= 1.10 to run this test - skipping") + return + torch.manual_seed(20211105) + x = torch.rand(20, 30, 5000) + tgt = torch.randint(low=-1, high=x.size(-1), size=x.shape[:2]) + for reduction in ["none", "sum", "mean"]: + custom_loss_func = LabelSmoothingLoss( + ignore_index=-1, label_smoothing=0.1, reduction=reduction + ) + custom_loss = custom_loss_func(x, tgt) + + torch_loss_func = torch.nn.CrossEntropyLoss( + ignore_index=-1, reduction=reduction, label_smoothing=0.1 + ) + torch_loss = torch_loss_func(x.reshape(-1, x.size(-1)), tgt.reshape(-1)) + assert torch.allclose(custom_loss, torch_loss) + + +def main(): + test_with_torch_label_smoothing_loss() + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/test_subsampling.py b/egs/swbd/ASR/conformer_ctc/test_subsampling.py new file mode 100755 index 000000000..81fa234dd --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from subsampling import Conv2dSubsampling, VggSubsampling + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/swbd/ASR/conformer_ctc/test_transformer.py b/egs/swbd/ASR/conformer_ctc/test_transformer.py new file mode 120000 index 000000000..8b0990ec6 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py new file mode 100755 index 000000000..7f1eebbcf --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/train.py @@ -0,0 +1,814 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./conformer_ctc/train.py \ + --exp-dir ./conformer_ctc/exp \ + --world-size 4 \ + --max-duration 200 \ + --num-epochs 20 +""" + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=98, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. 
+ The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. 
+ + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
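For a concrete feel of the `--att-rate` interpolation that `compute_loss` below implements, here is a minimal arithmetic sketch; the loss values are invented and only the weighting formula comes from the recipe.

import torch

# Hypothetical summed losses for one batch, purely for illustration.
ctc_loss = torch.tensor(150.0)
att_loss = torch.tensor(90.0)

att_rate = 0.8  # default of --att-rate in this recipe
loss = (1.0 - att_rate) * ctc_loss + att_rate * att_loss
print(loss)  # tensor(102.) = 0.2 * 150 + 0.8 * 90

# With --att-rate=0 (pure CTC training) the attention term drops out entirely.
print((1.0 - 0.0) * ctc_loss + 0.0 * att_loss)  # tensor(150.)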
+ """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if 
world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
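Note that `tot_loss` in `train_one_epoch` above is not a plain running sum: every batch it is first scaled by `1 - 1/reset_interval`, so contributions from old batches decay away over roughly `reset_interval` batches. A toy sketch with plain floats (MetricsTracker applies the same update to each tracked metric):

reset_interval = 200  # value from get_params()

tot_loss = 0.0
for batch_idx in range(1000):
    loss_info = 1.0  # pretend every batch contributes a loss of 1.0
    tot_loss = tot_loss * (1 - 1 / reset_interval) + loss_info

# The accumulator stays bounded near reset_interval (about 198.7 here)
# instead of growing to 1000, so it acts as a smoothed recent average
# scaled by the interval length.
print(round(tot_loss, 1))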
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + if "lang_bpe" in str(params.lang_dir): + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in str(params.lang_dir): + assert params.att_rate == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. Set --att-rate=0 " + "for pure CTC training when using a phone-based lang dir." + ) + assert params.num_decoder_layers == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. " + "Set --num-decoder-layers=0 for pure CTC training when using " + "a phone-based lang dir." + ) + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. + graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + switchboard = SwitchBoardAsrDataModule(args) + + train_cuts = switchboard.train_all_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = switchboard.train_dataloaders(train_cuts) + + valid_cuts = switchboard.dev_cuts() + valid_dl = switchboard.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/transformer.py b/egs/swbd/ASR/conformer_ctc/transformer.py new file mode 120000 index 000000000..1c3f43fcf --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_hlg.py b/egs/swbd/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/swbd/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_lg.py b/egs/swbd/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/swbd/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compute_fbank_eval2000.py b/egs/swbd/ASR/local/compute_fbank_eval2000.py new file mode 100755 index 000000000..d446e8ff3 --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_eval2000.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. 
If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + dir_name: str, + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}") + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all",) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = dir_name + suffix = "jsonl.gz" + manifests = { + "eval2000": "data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz", + } + assert manifests is not None + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) + + with get_executor() as ex: # Initialize the executor only once. + partition = "all" + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = CutSet.from_file(manifests[prefix]).resample(16000) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_switchboard( + dir_name="eval2000", + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py new file mode 100755 index 000000000..dd82220c0 --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_swbd.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. 
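The eval2000 script above and the swbd script that follows share the same Lhotse feature pipeline. A condensed, stand-alone sketch of that pipeline is shown here; the manifest path and output directory are placeholders, not files created by this recipe.

from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter

in_manifest = "data/manifests/example_cuts.jsonl.gz"  # hypothetical input
out_dir = "data/fbank/example"                        # hypothetical output dir

extractor = Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000))

cut_set = CutSet.from_file(in_manifest).resample(16000)
cut_set = cut_set.compute_and_store_features(
    extractor=extractor,
    storage_path=f"{out_dir}/example_feats",
    num_jobs=4,  # the recipe derives this from os.cpu_count() and the executor
    storage_type=LilcomChunkyWriter,  # compressed chunked storage used across icefall
)
# Keep one cut per supervision so training sees utterance-sized examples.
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
cut_set.to_file(f"{out_dir}/example_cuts.jsonl.gz")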
+""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + parser.add_argument( + "--split-index", + type=int, + required=True, + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + dir_name: str, + split_index: int, + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}_split16") + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all",) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = dir_name + suffix = "jsonl.gz" + split_dir = Path("data/manifests/swbd_split16/") + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) + + with get_executor() as ex: # Initialize the executor only once. 
+ partition = "all" + cuts_filename = ( + f"{prefix}_cuts_{partition}.{str(split_index).zfill(2)}.{suffix}" + ) + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = ( + CutSet.from_file( + split_dir + / f"swbd_train_all_trimmed.{str(split_index).zfill(2)}.jsonl.gz" + ) + .resample(16000) + .to_eager() + .filter(lambda c: c.duration > 2.0) + ) + + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}_{str(split_index).zfill(2)}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, + min_duration=None, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + compute_fbank_switchboard( + dir_name="swbd", + split_index=args.split_index, + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py new file mode 100755 index 000000000..a8d5117c9 --- /dev/null +++ b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +""" +Convert a transcript file containing words to a corpus file containing tokens +for LM training with the help of a lexicon. + +If the lexicon contains phones, the resulting LM will be a phone LM; If the +lexicon contains word pieces, the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +""" + +import argparse +from pathlib import Path +from typing import Dict, List + +from generate_unique_lexicon import filter_multiple_pronunications + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcript", + type=str, + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", + ) + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + parser.add_argument("--oov", type=str, default="", help="The OOV word.") + + return parser.parse_args() + + +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: + """ + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations (i.e., tokens). + line: + A line of transcript consisting of space(s) separated words. 
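The conversion that `process_line` (whose body follows) performs can be seen end to end with the toy hello/zoo/world lexicon from the module docstring above; everything in this sketch is illustrative.

# Toy lexicon: word -> first listed pronunciation (already de-duplicated).
lexicon = {
    "hello": ["h", "e", "l", "l", "o", "2"],
    "world": ["w", "o", "r", "l", "d"],
    "zoo": ["z", "o", "o"],
}
oov_token = ["SPN"]

line = "foo zoo world hello"
tokens = []
for word in line.split():
    # Unknown words fall back to the OOV pronunciation, as in process_line().
    tokens.extend(lexicon.get(word, oov_token))

print(" ".join(tokens))  # SPN z o o w o r l d h e l l o 2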
+ oov_token: + The pronunciation of the oov word if a word in `line` is not present + in the lexicon. + Returns: + Return None. + """ + s = "" + words = line.strip().split() + for i, w in enumerate(words): + tokens = lexicon.get(w, oov_token) + s += " ".join(tokens) + s += " " + print(s.strip()) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + + assert args.oov in lexicon + + oov_token = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_token=oov_token) + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/dict.patch b/egs/swbd/ASR/local/dict.patch new file mode 100644 index 000000000..12c63d612 --- /dev/null +++ b/egs/swbd/ASR/local/dict.patch @@ -0,0 +1,380 @@ +1d0 +< file: $SWB/data/dictionary/sw-ms98-dict.text +8645a8646 +> uh-hum ah m hh ah m +9006c9007 +< April ey p r ih l +--- +> April ey p r ax l +9144d9144 +< B ay zh aa n iy z +9261c9261 +< Battle b ae t el +--- +> Battle b ae t ax l +10014a10015 +> Chevy sh eh v iy +10211a10213 +> Colorado k ao l ax r aa d ow +10212a10215 +> Colorado' k ao l ax r aa d ow z +10370c10373 +< Creek k r ih k +--- +> Creek k r iy k +10889a10893 +> Eleven ax l eh v ih n +10951c10955 +< Erie ih r iy +--- +> Erie iy r iy +11183c11187 +< Forever f ax r eh v er +--- +> Forever f er eh v er +11231a11236 +> Friday f r ay d iy +11744a11750 +> History hh ih s t r iy +12004a12011,12012 +> Israel ih z r ih l +> Israel's ih z r ih l z +12573a12582 +> Lincoln l ih ng k ih n +12574a12584 +> Lincolns l ih ng k ih n z +13268c13278 +< NAACP eh ey ey s iy p iy +--- +> NAACP eh n ey ey s iy p iy +13286c13296 +< NIT eh ay t iy +--- +> NIT eh n ay t iy +13292c13302 +< NTSC eh t iy eh s s iy +--- +> NTSC eh n t iy eh s s iy +14058a14069 +> Quarter k ow r t er +14059a14071 +> Quarterback k ow r t er b ae k +14060a14073 +> Quarters k ow r t er z +14569a14583 +> Science s ay n s +15087a15102 +> Sunday s ah n d iy +15088a15104 +> Sunday's s ah n d iy z +15089a15106 +> Sundays s ah n d iy z +15290,15291c15307,15308 +< Texan t eh k sh ih n +< Texan's t eh k sh ih n s +--- +> Texan t eh k s ih n +> Texan's t eh k s ih n s +15335a15353 +> Thousands th aw z ih n z +15739c15757 +< Waco w ae k ow +--- +> Waco w ey k ow +15841a15860 +> Weekends w iy k eh n z +16782a16802 +> acceptable eh k s eh p ax b ax l +16833a16854 +> accounting ax k aw n ih ng +16948a16970 +> address ax d r eh s +17281a17304 +> already aa r d iy +17315a17339 +> am m +17709a17734 +> asked ae s t +17847a17873 +> attorney ih t er n iy +17919a17946 +> autopilot ao t ow p ay l ih t +17960a17988 +> awfully ao f l iy +18221a18250 +> basketball b ae s k ax b ao l +18222a18252 +> basketball's b ae s k ax b ao l z +18302a18333 +> become b ah k ah m +18303a18335 +> becomes b iy k ah m z +18344a18377 +> began b ax g en n +18817c18850 +< bottle b aa t el +--- +> bottle b aa t ax l +19332,19333c19365,19367 +< camera's k ae m ax r ax z +< cameras k ae m ax r ax z +--- +> camera k ae m r ax +> camera's k ae m r ax z +> cameras k ae m r ax z +19411a19446 +> capital k ae p ax l +19505a19541 +> carrying k ae r ih ng +20316a20353,20354 +> combination k aa m ih n ey sh ih n +> combinations k aa m ih n ey sh ih n z +20831a20870 +> contracts k aa n t r ae k s +21010a21050 +> costs k ao s +21062a21103 +> 
county k aw n iy +21371a21413 +> cultural k ao l ch ax r ax l +21372a21415 +> culturally k ao l ch ax r ax l iy +21373a21417 +> culture k ao l ch er +21375a21420 +> cultures k ao l ch er z +21543a21589 +> data d ey t ax +22097a22144 +> differently d ih f ax r ih n t l iy +22972a23020 +> effects ax f eh k t s +23016a23065 +> election ax l eh k sh ih n +23018a23068 +> elections ax l eh k sh ih n z +23052a23103 +> eleven ax l eh v ih n +23242a23294 +> enjoyable ae n jh oy ax b ax l +23248a23301 +> enjoys ae n jh oy z +23293a23347 +> entire ih n t ay r +23295a23350,23351 +> entirely ih n t ay r l iy +> entirety ih n t ay r t iy +23745a23802 +> extra eh k s t er +23818a23876 +> facts f ae k s +24508c24566 +< forever f ax r eh v er +--- +> forever f er eh v er +24514c24572 +< forget f ow r g eh t +--- +> forget f er r g eh t +24521a24580 +> forgot f er r g aa t +24522a24582 +> forgotten f er r g aa t ax n +24563a24624 +> forward f ow er d +24680a24742 +> frightening f r ay t n ih ng +24742a24805 +> full-time f ax l t ay m +24862a24926 +> garage g r aa jh +25218a25283 +> grandmother g r ae m ah dh er +25790a25856 +> heavily hh eh v ax l iy +25949a26016 +> history hh ih s t r iy +26038a26106 +> honestly aa n ax s t l iy +26039a26108 +> honesty aa n ax s t iy +26099a26169 +> horror hh ow r +26155a26226 +> houses hh aw z ih z +26184c26255 +< huh-uh hh ah hh ah +--- +> huh-uh ah hh ah +26189c26260 +< hum-um hh m hh m +--- +> hum-um ah m hh ah m +26236a26308 +> hunting hh ah n ih ng +26307a26380,26381 +> ideal ay d iy l +> idealist ay d iy l ih s t +26369a26444 +> imagine m ae jh ih n +26628a26704 +> individuals ih n d ih v ih jh ax l z +26968a27045 +> interest ih n t r ih s t +27184a27262 +> it'd ih d +27702a27781 +> lead l iy d +28378a28458 +> mandatory m ae n d ih t ow r iy +28885a28966 +> minute m ih n ih t +29167a29249 +> mountains m aw t n z +29317a29400 +> mysteries m ih s t r iy z +29318a29402 +> mystery m ih s t r iy +29470a29555 +> nervous n er v ih s +29578,29580c29663,29665 +< nobody n ow b aa d iy +< nobody'll n ow b aa d iy l +< nobody's n ow b aa d iy z +--- +> nobody n ow b ah d iy +> nobody'll n ow b ah d iy l +> nobody's n ow b ah d iy z +29712a29798 +> nuclear n uw k l iy r +29938a30025 +> onto aa n t ax +30051a30139 +> originally ax r ih jh ax l iy +30507a30596 +> particularly p er t ih k y ax l iy +30755a30845 +> perfectly p er f ih k l iy +30820a30911 +> personally p er s n ax l iy +30915a31007 +> physically f ih z ih k l iy +30986a31079 +> pilot p ay l ih t +30987a31081 +> pilot's p ay l ih t s +31227a31322 +> police p l iy s +31513a31609 +> prefer p er f er +31553a31650 +> prepare p r ax p ey r +31578a31676 +> prescription p er s k r ih p sh ih n +31579a31678 +> prescriptions p er s k r ih p sh ih n z +31770a31870 +> products p r aa d ax k s +31821a31922 +> projects p r aa jh eh k s +31908a32010 +> protect p er t eh k t +31909a32012 +> protected p er t eh k t ih d +31911a32015 +> protection p er t eh k sh ih n +31914a32019 +> protection p er t eh k t ih v +32149a32255 +> quarter k ow r t er +32414a32521 +> read r iy d +32785a32893 +> rehabilitation r iy ax b ih l ih t ey sh ih n +33150a33259 +> resource r ih s ow r s +33151a33261 +> resources r iy s ow r s ih z +33539c33649 +< roots r uh t s +--- +> roots r uw t s +33929a34040 +> science s ay n s +34315a34427 +> seventy s eh v ih n iy +34319,34320c34431,34432 +< severe s ax v iy r +< severely s ax v iy r l iy +--- +> severe s ih v iy r +> severely s ih v iy r l iy +35060a35173 +> software s ao f w ey r +35083a35197 +> solid s 
ao l ih d +35084a35199 +> solidly s ao l ih d l iy +35750a35866 +> stood s t ih d +35854a35971 +> strictly s t r ih k l iy +35889c36006 +< stronger s t r ao ng er +--- +> stronger s t r ao ng g er +36192a36310,36311 +> supposed s p ow z +> supposed s p ow s +36510a36630 +> tastes t ey s +36856a36977 +> thoroughly th er r l iy +36866a36988 +> thousands th aw z ih n z +37081c37203 +< toots t uh t s +--- +> toots t uw t s +37157a37280 +> toward t w ow r d +37158a37282 +> towards t w ow r d z +37564a37689 +> twenties t w eh n iy z +37565a37691 +> twentieth t w eh n iy ih th +37637a37764 +> unacceptable ah n ae k s eh p ax b ax l +37728a37856 +> understand ah n d er s t ae n +37860a37989 +> unless ih n l eh s +38040a38170 +> use y uw z +38049a38180 +> uses y uw z ih z +38125a38257 +> various v ah r iy ih s +38202a38335 +> versus v er s ih z +38381c38514 +< wacko w ae k ow +--- +> wacko w ey k ow +38455c38588 +< wanna w aa n ax +--- +> wanna w ah n ax +38675c38808 +< whatnot w ah t n aa t +--- +> whatnot w aa t n aa t +38676a38810 +> whatsoever w aa t s ow eh v er +38890c39024 +< wok w aa k +--- +> wok w ao k +38910a39045 +> wondering w ah n d r ih ng diff --git a/egs/swbd/ASR/local/display_manifest_statistics.py b/egs/swbd/ASR/local/display_manifest_statistics.py new file mode 100755 index 000000000..9aa204863 --- /dev/null +++ b/egs/swbd/ASR/local/display_manifest_statistics.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. 
+""" + + +from lhotse import load_manifest_lazy + + +def main(): + # path = "./data/fbank/swbd_cuts_rt03.jsonl.gz" + path = "./data/fbank/eval2000/eval2000_cuts_all.jsonl.gz" + # path = "./data/fbank/swbd_cuts_all.jsonl.gz" + + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Training Cut statistics: +╒═══════════════════════════╤═══════════╕ +│ Cuts count: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Total duration (hh:mm:ss) │ 281:01:26 │ +├───────────────────────────┼───────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼───────────┤ +│ std │ 3.3 │ +├───────────────────────────┼───────────┤ +│ min │ 2.0 │ +├───────────────────────────┼───────────┤ +│ 25% │ 3.2 │ +├───────────────────────────┼───────────┤ +│ 50% │ 5.2 │ +├───────────────────────────┼───────────┤ +│ 75% │ 8.3 │ +├───────────────────────────┼───────────┤ +│ 99% │ 14.4 │ +├───────────────────────────┼───────────┤ +│ 99.5% │ 14.7 │ +├───────────────────────────┼───────────┤ +│ 99.9% │ 15.0 │ +├───────────────────────────┼───────────┤ +│ max │ 57.5 │ +├───────────────────────────┼───────────┤ +│ Recordings available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Features available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Supervisions available: │ 167244 │ +╘═══════════════════════════╧═══════════╛ +Speech duration statistics: +╒══════════════════════════════╤═══════════╤══════════════════════╕ +│ Total speech duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total speaking time duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧═══════════╧══════════════════════╛ + +Eval2000 Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 03:37:13 │ +├───────────────────────────┼──────────┤ +│ mean │ 2.9 │ +├───────────────────────────┼──────────┤ +│ std │ 2.6 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 1.2 │ +├───────────────────────────┼──────────┤ +│ 50% │ 2.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 4.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 12.6 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 13.7 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 14.7 │ +├───────────────────────────┼──────────┤ +│ max │ 15.5 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 4473 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 03:37:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 03:37:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +""" diff --git a/egs/swbd/ASR/local/extend_segments.pl b/egs/swbd/ASR/local/extend_segments.pl new file mode 100755 index 000000000..e8b4894d5 --- /dev/null +++ 
b/egs/swbd/ASR/local/extend_segments.pl @@ -0,0 +1,99 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter + +if (@ARGV != 1 || !($ARGV[0] =~ m/^-?\d+\.?\d*$/ && $ARGV[0] >= 0)) { + print STDERR "Usage: extend_segments.pl time-in-seconds segments.extended \n" . + "e.g. extend_segments.pl 0.25 segments.2\n" . + "This command modifies a segments file, with lines like\n" . + " \n" . + "by extending the beginning and end of each segment by a certain\n" . + "length of time. This script makes sure the output segments do not\n" . + "overlap as a result of this time-extension, and that there are no\n" . + "negative times in the output.\n"; + exit 1; +} + +$extend = $ARGV[0]; + +@all_lines = (); + +while () { + chop; + @A = split(" ", $_); + if (@A != 4) { + die "invalid line in segments file: $_"; + } + $line = @all_lines; # current number of lines. + ($utt_id, $reco_id, $start_time, $end_time) = @A; + + push @all_lines, [ $utt_id, $reco_id, $start_time, $end_time ]; # anonymous array. + if (! defined $lines_for_reco{$reco_id}) { + $lines_for_reco{$reco_id} = [ ]; # push new anonymous array. + } + push @{$lines_for_reco{$reco_id}}, $line; +} + +foreach $reco_id (keys %lines_for_reco) { + $ref = $lines_for_reco{$reco_id}; + @line_numbers = sort { ${$all_lines[$a]}[2] <=> ${$all_lines[$b]}[2] } @$ref; + + + { + # handle start of earliest segment as a special case. + $l0 = $line_numbers[0]; + $tstart = ${$all_lines[$l0]}[2] - $extend; + if ($tstart < 0.0) { $tstart = 0.0; } + ${$all_lines[$l0]}[2] = $tstart; + } + { + # handle end of latest segment as a special case. + $lN = $line_numbers[$#line_numbers]; + $tend = ${$all_lines[$lN]}[3] + $extend; + ${$all_lines[$lN]}[3] = $tend; + } + for ($i = 0; $i < $#line_numbers; $i++) { + $ln = $line_numbers[$i]; + $ln1 = $line_numbers[$i+1]; + $tend = ${$all_lines[$ln]}[3]; # end of earlier segment. + $tstart = ${$all_lines[$ln1]}[2]; # start of later segment. + if ($tend > $tstart) { + $utt1 = ${$all_lines[$ln]}[0]; + $utt2 = ${$all_lines[$ln1]}[0]; + print STDERR "Warning: for utterances $utt1 and $utt2, segments " . + "already overlap; leaving these times unchanged.\n"; + } else { + $my_extend = $extend; + $max_extend = 0.5 * ($tstart - $tend); + if ($my_extend > $max_extend) { $my_extend = $max_extend; } + $tend += $my_extend; + $tstart -= $my_extend; + ${$all_lines[$ln]}[3] = $tend; + ${$all_lines[$ln1]}[2] = $tstart; + } + } +} + +# leave the numbering of the lines unchanged. +for ($l = 0; $l < @all_lines; $l++) { + $ref = $all_lines[$l]; + ($utt_id, $reco_id, $start_time, $end_time) = @$ref; + printf("%s %s %.2f %.2f\n", $utt_id, $reco_id, $start_time, $end_time); +} + +__END__ + +# testing below. + +# ( echo a1 A 0 1; echo a2 A 3 4; echo b1 B 0 1; echo b2 B 2 3 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.00 +a2 A 2.00 5.00 +b1 B 0.00 1.50 +b2 B 1.50 4.00 +# ( echo a1 A 0 2; echo a2 A 1 3 ) | local/extend_segments.pl 1.0 +Warning: for utterances a1 and a2, segments already overlap; leaving these times unchanged. +a1 A 0.00 2.00 +a2 A 1.00 4.00 +# ( echo a1 A 0 2; echo a2 A 5 6; echo a3 A 3 4 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.50 +a2 A 4.50 7.00 +a3 A 2.50 4.50 diff --git a/egs/swbd/ASR/local/filter_cuts.py b/egs/swbd/ASR/local/filter_cuts.py new file mode 100755 index 000000000..fbcc9e24a --- /dev/null +++ b/egs/swbd/ASR/local/filter_cuts.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script removes short and long utterances from a cutset. + +Caution: + You may need to tune the thresholds for your own dataset. + +Usage example: + + python3 ./local/filter_cuts.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ + --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=Path, + help="Path to the bpe.model", + ) + + parser.add_argument( + "--in-cuts", + type=Path, + help="Path to the input cutset", + ) + + parser.add_argument( + "--out-cuts", + type=Path, + help="Path to the output cutset", + ) + + return parser.parse_args() + + +def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): + total = 0 # number of total utterances before removal + removed = 0 # number of removed utterances + + def remove_short_and_long_utterances(c: Cut): + """Return False to exclude the input cut""" + nonlocal removed, total + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ./display_manifest_statistics.py + # + # You should use ./display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + total += 1 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + removed += 1 + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./pruned_transducer_stateless2/conformer.py, the + # conv module uses the following expression + # for subsampling + if c.num_frames is None: + num_frames = c.duration * 100 # approximate + else: + num_frames = c.num_frames + + T = ((num_frames - 1) // 2 - 1) // 2 + # Note: for ./lstm_transducer_stateless/lstm.py, the formula is + # T = ((num_frames - 3) // 2 - 1) // 2 + + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + removed += 1 + return False + + return True + + # We use to_eager() here so that we can print out the value of total + # and removed below. 
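As a quick numeric check of the subsampling arithmetic used just above (the formula is taken from this script; the 100 frames-per-second figure matches its fallback when num_frames is missing):

def frames_after_subsampling(num_frames: int) -> int:
    # Convolutional subsampling reduces T by roughly a factor of 4.
    return ((num_frames - 1) // 2 - 1) // 2

duration = 2.5                    # seconds, hypothetical utterance
num_frames = int(duration * 100)  # 10 ms frame shift -> 250 frames
T = frames_after_subsampling(num_frames)
print(T)  # 61

# A cut is excluded when T < len(tokens): a 2.5 s utterance can carry
# at most 61 BPE tokens under this subsampling.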
+ ans = cut_set.filter(remove_short_and_long_utterances).to_eager() + ratio = removed / total * 100 + logging.info( + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + ) + return ans + + +def main(): + args = get_args() + logging.info(vars(args)) + + if args.out_cuts.is_file(): + logging.info(f"{args.out_cuts} already exists - skipping") + return + + assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" + assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + cut_set = load_manifest_lazy(args.in_cuts) + assert isinstance(cut_set, CutSet) + + cut_set = filter_cuts(cut_set, sp) + logging.info(f"Saving to {args.out_cuts}") + args.out_cuts.parent.mkdir(parents=True, exist_ok=True) + cut_set.to_file(args.out_cuts) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/filter_empty_text.py b/egs/swbd/ASR/local/filter_empty_text.py new file mode 100755 index 000000000..6b3316800 --- /dev/null +++ b/egs/swbd/ASR/local/filter_empty_text.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path +import logging +from typing import List + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--kaldi-data-dir", + type=Path, + required=True, + help="Path to the kaldi data dir", + ) + + return parser.parse_args() + + +def load_segments(path: Path): + segments = {} + with open(path, "r") as f: + lines = f.readlines() + for line in lines: + line = line.strip() + utt_id, rec_id, start, end = line.split() + segments[utt_id] = line + return segments + + +def filter_text(path: Path): + with open(path, "r") as f: + lines = f.readlines() + return list(filter(lambda x: len(x.strip().split()) > 1, lines)) + + +def write_segments(path: Path, texts: List[str]): + with open(path, "w") as f: + f.writelines(texts) + + +def main(): + args = get_args() + orig_text_dict = filter_text(args.kaldi_data_dir / "text") + write_segments(args.kaldi_data_dir / "text", orig_text_dict) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() + + logging.info("Empty lines filtered") diff --git a/egs/swbd/ASR/local/format_acronyms_dict.py b/egs/swbd/ASR/local/format_acronyms_dict.py new file mode 100755 index 000000000..fa598dd03 --- /dev/null +++ b/egs/swbd/ASR/local/format_acronyms_dict.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd dict to fisher convention +# IBM to i._b._m. +# BBC to b._b._c. 
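The Fisher-style spelling this script produces (and that map_acronyms_transcripts.py later in the patch consumes) joins lower-cased letters with "._". Below is a stand-alone sketch of just that string transformation for plain acronyms; the real script additionally verifies the pronunciation letter by letter and handles the trailing "s" and "'s" forms separately.

def to_fisher_acronym(word: str) -> str:
    """Map e.g. 'IBM' -> 'i._b._m.' (plain acronyms only)."""
    return "._".join(c.lower() for c in word) + "."

print(to_fisher_acronym("IBM"))  # i._b._m.
print(to_fisher_acronym("BBC"))  # b._b._c.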
+# BBCs to b._b._c.s +# BBC's to b._b._c.'s + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input lexicon", required=True) +parser.add_argument("-o", "--output", help="Output lexicon", required=True) +parser.add_argument( + "-L", "--Letter", help="Input single letter pronunciation", required=True +) +parser.add_argument("-M", "--Map", help="Output acronyms mapping", required=True) +args = parser.parse_args() + + +fin_lex = open(args.input, "r") +fin_Letter = open(args.Letter, "r") +fout_lex = open(args.output, "w") +fout_map = open(args.Map, "w") + +# Initialise single letter dictionary +dict_letter = {} +for single_letter_lex in fin_Letter: + items = single_letter_lex.split() + dict_letter[items[0]] = single_letter_lex[len(items[0]) + 1 :].strip() +fin_Letter.close() +# print dict_letter + +for lex in fin_lex: + items = lex.split() + word = items[0] + lexicon = lex[len(items[0]) + 1 :].strip() + # find acronyms from words with only letters and ' + pre_match = re.match(r"^[A-Za-z]+$|^[A-Za-z]+\'s$|^[A-Za-z]+s$", word) + if pre_match: + # find if words in the form of xxx's is acronym + if word[-2:] == "'s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-2] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".'s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxxs is acronym + elif word[-1] == "s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-1] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxx (not ended with 's or s) is acronym + elif word.find("'") == -1 and word[-1] != "s": + acronym_lexicon = "" + for w in word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + word[-1].lower() + "." 
+ acronym_mapped_back = acronym_mapped_back + word[-1].lower() + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + else: + fout_lex.write(lex) + + else: + fout_lex.write(lex) diff --git a/egs/swbd/ASR/local/generate_unique_lexicon.py b/egs/swbd/ASR/local/generate_unique_lexicon.py new file mode 100755 index 000000000..3459c2f5a --- /dev/null +++ b/egs/swbd/ASR/local/generate_unique_lexicon.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input a lexicon.txt and output a new lexicon, +in which each word has a unique pronunciation. + +The way to do this is to keep only the first pronunciation of a word +in lexicon.txt. +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +from icefall.lexicon import read_lexicon, write_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + This file will generate a new file uniq_lexicon.txt + in it. + """, + ) + + return parser.parse_args() + + +def filter_multiple_pronunications( + lexicon: List[Tuple[str, List[str]]] +) -> List[Tuple[str, List[str]]]: + """Remove multiple pronunciations of words from a lexicon. + + If a word has more than one pronunciation in the lexicon, only + the first one is kept, while other pronunciations are removed + from the lexicon. + + Args: + lexicon: + The input lexicon, containing a list of (word, [p1, p2, ..., pn]), + where "p1, p2, ..., pn" are the pronunciations of the "word". + Returns: + Return a new lexicon where each word has a unique pronunciation. 
+ """ + seen = set() + ans = [] + + for word, tokens in lexicon: + if word in seen: + continue + seen.add(word) + ans.append((word, tokens)) + return ans + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + lexicon_filename = lang_dir / "lexicon.txt" + + in_lexicon = read_lexicon(lexicon_filename) + + out_lexicon = filter_multiple_pronunications(in_lexicon) + + write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) + + logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") + logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/map_acronyms_transcripts.py b/egs/swbd/ASR/local/map_acronyms_transcripts.py new file mode 100755 index 000000000..ba02aaec3 --- /dev/null +++ b/egs/swbd/ASR/local/map_acronyms_transcripts.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd transcript to fisher convention +# according to first two columns in the input acronyms mapping + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input transcripts", required=True) +parser.add_argument("-o", "--output", help="Output transcripts", required=True) +parser.add_argument("-M", "--Map", help="Input acronyms mapping", required=True) +args = parser.parse_args() + +fin_map = open(args.Map, "r") +dict_acronym = {} +dict_acronym_noi = {} # Mapping of acronyms without I, i +for pair in fin_map: + items = pair.split("\t") + dict_acronym[items[0]] = items[1] + dict_acronym_noi[items[0]] = items[1] +fin_map.close() +del dict_acronym_noi["I"] +del dict_acronym_noi["i"] + + +fin_trans = open(args.input, "r") +fout_trans = open(args.output, "w") +for line in fin_trans: + items = line.split() + L = len(items) + # First pass mapping to map I as part of acronym + for i in range(L): + if items[i] == "I": + x = 0 + while i - 1 - x >= 0 and re.match(r"^[A-Z]$", items[i - 1 - x]): + x += 1 + + y = 0 + while i + 1 + y < L and re.match(r"^[A-Z]$", items[i + 1 + y]): + y += 1 + + if x + y > 0: + for bias in range(-x, y + 1): + items[i + bias] = dict_acronym[items[i + bias]] + + # Second pass mapping (not mapping 'i' and 'I') + for i in range(len(items)): + if items[i] in dict_acronym_noi.keys(): + items[i] = dict_acronym_noi[items[i]] + sentence = " ".join(items[1:]) + fout_trans.write(items[0] + " " + sentence.lower() + "\n") + +fin_trans.close() +fout_trans.close() diff --git a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py new file mode 100755 index 000000000..20ab90caf --- /dev/null +++ b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +# replacement function to convert lowercase letter to uppercase +def to_upper(match_obj): + if match_obj.group() is not None: + return match_obj.group().upper() + + +def insert_groups_and_capitalize_3(match): + return f"{match.group(1)} {match.group(2)} {match.group(3)}".upper() + + +def insert_groups_and_capitalize_2(match): + return f"{match.group(1)} {match.group(2)}".upper() + + +def insert_groups_and_capitalize_1(match): + return f"{match.group(1)}".upper() + + +def insert_groups_and_capitalize_1s(match): + return f"{match.group(1)}".upper() + "'s" + + +class FisherSwbdNormalizer: + """Note: the functions "normalize" and "keep" implement the logic + similar to Kaldi's data prep scripts for Fisher and SWBD: One + notable difference is that we don't change [cough], [lipsmack], + etc. to [noise]. We also don't implement all the edge cases of + normalization from Kaldi (hopefully won't make too much + difference). + """ + + def __init__(self) -> None: + self.remove_regexp_before = re.compile( + r"|".join( + [ + # special symbols + r"\[\[skip.*\]\]", + r"\[skip.*\]", + r"\[pause.*\]", + r"\[silence\]", + r"", + r"", + r"_1", + ] + ) + ) + + # tuples of (pattern, replacement) + # note: Kaldi replaces sighs, coughs, etc with [noise]. + # We don't do that here. + # We also lowercase the text as the first operation. 
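        # Editorial sketch (not part of the original patch): assuming the
        # regexps below behave as their inline comments describe, a raw SWBD
        # line such as
        #     "[LAUGHTER-STORY] okay_1 [silence]"
        # flows through normalize() roughly as
        #     lowercase            -> "[laughter-story] okay_1 [silence]"
        #     remove_regexp_before -> "[laughter-story] okay "
        #     replace_regexps      -> "story okay "
        #     whitespace + upper   -> "STORY OKAY"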
+ self.replace_regexps: Tuple[re.Pattern, str] = [ + # SWBD: + # [LAUGHTER-STORY] -> STORY + (re.compile(r"\[laughter-(.*?)\]"), r"\1"), + # [WEA[SONABLE]-/REASONABLE] + (re.compile(r"\[\S+/(\S+)\]"), r"\1"), + # -[ADV]AN[TAGE]- -> AN + (re.compile(r"-?\[.*?\](\w+)\[.*?\]-?"), r"\1-"), + # ABSOLUTE[LY]- -> ABSOLUTE- + (re.compile(r"(\w+)\[.*?\]-?"), r"\1-"), + # [AN]Y- -> Y- + # -[AN]Y- -> Y- + (re.compile(r"-?\[.*?\](\w+)-?"), r"\1-"), + # special tokens + (re.compile(r"\[laugh.*?\]"), r"[laughter]"), + (re.compile(r"\[sigh.*?\]"), r"[sigh]"), + (re.compile(r"\[cough.*?\]"), r"[cough]"), + (re.compile(r"\[mn.*?\]"), r"[vocalized-noise]"), + (re.compile(r"\[breath.*?\]"), r"[breath]"), + (re.compile(r"\[lipsmack.*?\]"), r"[lipsmack]"), + (re.compile(r"\[sneeze.*?\]"), r"[sneeze]"), + # abbreviations + ( + re.compile( + r"(\w)\.(\w)\.(\w)", + ), + insert_groups_and_capitalize_3, + ), + ( + re.compile( + r"(\w)\.(\w)", + ), + insert_groups_and_capitalize_2, + ), + ( + re.compile( + r"([a-h,j-z])\.", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"\._", + ), + r" ", + ), + ( + re.compile( + r"_(\w)", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"(\w)\.s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"([A-Z])\'s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"(\s\w\b|^\w\b)", + ), + insert_groups_and_capitalize_1, + ), + # words between apostrophes + (re.compile(r"'(\S*?)'"), r"\1"), + # dangling dashes (2 passes) + (re.compile(r"\s-\s"), r" "), + (re.compile(r"\s-\s"), r" "), + # special symbol with trailing dash + (re.compile(r"(\[.*?\])-"), r"\1"), + # Just remove all dashes + (re.compile(r"-"), r" "), + ] + + # unwanted symbols in the transcripts + self.remove_regexp_after = re.compile( + r"|".join( + [ + # remaining punctuation + r"\.", + r",", + r"\?", + r"{", + r"}", + r"~", + r"_\d", + ] + ) + ) + + self.post_fixes = [ + # Fix an issue related to [VOCALIZED NOISE] after dash removal + (re.compile(r"\[vocalized noise\]"), "[vocalized-noise]"), + ] + + self.whitespace_regexp = re.compile(r"\s+") + + def normalize(self, text: str) -> str: + text = text.lower() + + # first remove + text = self.remove_regexp_before.sub("", text) + + # then replace + for pattern, sub in self.replace_regexps: + text = pattern.sub(sub, text) + + # then remove + text = self.remove_regexp_after.sub("", text) + + # post fixes + for pattern, sub in self.post_fixes: + text = pattern.sub(sub, text) + + # then clean up whitespace + text = self.whitespace_regexp.sub(" ", text).strip() + + return text.upper() + + +def keep(sup: SupervisionSegment) -> bool: + if "((" in sup.text: + return False + + if " yes", + "[laugh] oh this is [laught] this is great [silence] yes", + "i don't kn- - know A.B.C's", + "so x. 
corp is good?", + "'absolutely yes", + "absolutely' yes", + "'absolutely' yes", + "'absolutely' yes 'aight", + "ABSOLUTE[LY]", + "ABSOLUTE[LY]-", + "[AN]Y", + "[AN]Y-", + "[ADV]AN[TAGE]", + "[ADV]AN[TAGE]-", + "-[ADV]AN[TAGE]", + "-[ADV]AN[TAGE]-", + "[WEA[SONABLE]-/REASONABLE]", + "[VOCALIZED-NOISE]-", + "~BULL", + "Frank E Peretti P E R E T T I", + "yeah yeah like Double O Seven he's supposed to do it", + "P A P E R paper", + "[noise] okay_1 um let me see [laughter] i've been sitting here awhile", + ]: + print(text) + print(normalizer.normalize(text)) + print() + + +if __name__ == "__main__": + test() + # exit() + main() diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py new file mode 100755 index 000000000..7316193d0 --- /dev/null +++ b/egs/swbd/ASR/local/normalize_eval2000.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +def remove_punctutation_and_other_symbol(text: str) -> str: + text = text.replace("--", " ") + text = text.replace("//", " ") + text = text.replace(".", " ") + text = text.replace("?", " ") + text = text.replace("~", " ") + text = text.replace(",", " ") + text = text.replace(";", " ") + text = text.replace("(", " ") + text = text.replace(")", " ") + text = text.replace("&", " ") + text = text.replace("%", " ") + text = text.replace("*", " ") + text = text.replace("{", " ") + text = text.replace("}", " ") + return text + + +def eval2000_clean_eform(text: str, eform_count) -> str: + string_to_remove = [] + piece = text.split('">') + for i in range(0, len(piece)): + s = piece[i] + '">' + res = re.search(r"", s) + if res is not None: + res_rm = res.group(1) + string_to_remove.append(res_rm) + for p in string_to_remove: + eform_string = p + text = text.replace(eform_string, " ") + eform_1 = " str: + text = text.replace("[/BABY CRYING]", " ") + text = text.replace("[/CHILD]", " ") + text = text.replace("[[DISTORTED]]", " ") + text = text.replace("[/DISTORTION]", " ") + text = text.replace("[[DRAWN OUT]]", " ") + text = text.replace("[[DRAWN-OUT]]", " ") + text = text.replace("[[FAINT]]", " ") + text = text.replace("[SMACK]", " ") + text = text.replace("[[MUMBLES]]", " ") + text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ") + text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]") + text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]") + text = text.replace("[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " ") + text = text.replace("[[PREVIOUS WORD SPOKEN WITH A 
LAUGH]]", " ") + text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ") + text = text.replace("[[PROLONGED]]", " ") + text = text.replace("[/RUNNING WATER]", " ") + text = text.replace("[[SAYS LAUGHING]]", "[LAUGHTER]") + text = text.replace("[[SINGING]]", " ") + text = text.replace("[[SPOKEN WHILE LAUGHING]]", "[LAUGHTER]") + text = text.replace("[/STATIC]", " ") + text = text.replace("['THIRTIETH' DRAWN OUT]", " ") + text = text.replace("[/VOICES]", " ") + text = text.replace("[[WHISPERED]]", " ") + text = text.replace("[DISTORTION]", " ") + text = text.replace("[DISTORTION, HIGH VOLUME ON WAVES]", " ") + text = text.replace("[BACKGROUND LAUGHTER]", "[LAUGHTER]") + text = text.replace("[CHILD'S VOICE]", " ") + text = text.replace("[CHILD SCREAMS]", " ") + text = text.replace("[CHILD VOICE]", " ") + text = text.replace("[CHILD YELLING]", " ") + text = text.replace("[CHILD SCREAMING]", " ") + text = text.replace("[CHILD'S VOICE IN BACKGROUND]", " ") + text = text.replace("[CHANNEL NOISE]", " ") + text = text.replace("[CHANNEL ECHO]", " ") + text = text.replace("[ECHO FROM OTHER CHANNEL]", " ") + text = text.replace("[ECHO OF OTHER CHANNEL]", " ") + text = text.replace("[CLICK]", " ") + text = text.replace("[DISTORTED]", " ") + text = text.replace("[BABY CRYING]", " ") + text = text.replace("[METALLIC KNOCKING SOUND]", " ") + text = text.replace("[METALLIC SOUND]", " ") + + text = text.replace("[PHONE JIGGLING]", " ") + text = text.replace("[BACKGROUND SOUND]", " ") + text = text.replace("[BACKGROUND VOICE]", " ") + text = text.replace("[BACKGROUND VOICES]", " ") + text = text.replace("[BACKGROUND NOISE]", " ") + text = text.replace("[CAR HORNS IN BACKGROUND]", " ") + text = text.replace("[CAR HORNS]", " ") + text = text.replace("[CARNATING]", " ") + text = text.replace("[CRYING CHILD]", " ") + text = text.replace("[CHOPPING SOUND]", " ") + text = text.replace("[BANGING]", " ") + text = text.replace("[CLICKING NOISE]", " ") + text = text.replace("[CLATTERING]", " ") + text = text.replace("[ECHO]", " ") + text = text.replace("[KNOCK]", " ") + text = text.replace("[NOISE-GOOD]", "[NOISE]") + text = text.replace("[RIGHT]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[SQUEAK]", " ") + text = text.replace("[STATIC]", " ") + text = text.replace("[[SAYS WITH HIGH-PITCHED SCREAMING LAUGHTER]]", " ") + text = text.replace("[UH]", "UH") + text = text.replace("[MN]", "[VOCALIZED-NOISE]") + text = text.replace("[VOICES]", " ") + text = text.replace("[WATER RUNNING]", " ") + text = text.replace("[SOUND OF TWISTING PHONE CORD]", " ") + text = text.replace("[SOUND OF SOMETHING FALLING]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[NOISE OF MOVING PHONE]", " ") + text = text.replace("[SOUND OF RUNNING WATER]", " ") + text = text.replace("[CHANNEL]", " ") + text = text.replace("[SILENCE]", " ") + text = text.replace("-[W]HERE", "WHERE") + text = text.replace("Y[OU]I-", "YOU I") + text = text.replace("-[A]ND", "AND") + text = text.replace("JU[ST]", "JUST") + text = text.replace("{BREATH}", " ") + text = text.replace("{BREATHY}", " ") + text = text.replace("{CHANNEL NOISE}", " ") + text = text.replace("{CLEAR THROAT}", " ") + + text = text.replace("{CLEARING THROAT}", " ") + text = text.replace("{CLEARS THROAT}", " ") + text = text.replace("{COUGH}", " ") + text = text.replace("{DRAWN OUT}", " ") + text = text.replace("{EXHALATION}", " ") + text = text.replace("{EXHALE}", " ") + text = text.replace("{GASP}", " ") + text = text.replace("{HIGH 
SQUEAL}", " ") + text = text.replace("{INHALE}", " ") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LIPSMACK}", " ") + text = text.replace("{LIPSMACK}", " ") + + text = text.replace("{NOISE OF DISGUST}", " ") + text = text.replace("{SIGH}", " ") + text = text.replace("{SNIFF}", " ") + text = text.replace("{SNORT}", " ") + text = text.replace("{SHARP EXHALATION}", " ") + text = text.replace("{BREATH LAUGH}", " ") + + text = text.replace("[LAUGHTER]", " ") + text = text.replace("[NOISE]", " ") + text = text.replace("[VOCALIZED-NOISE]", " ") + text = text.replace("-", " ") + return text + + +def remove_languagetag(text: str) -> str: + langtag = re.findall(r"<(.*?)>", text) + for t in langtag: + text = text.replace(t, " ") + text = text.replace("<", " ") + text = text.replace(">", " ") + return text + + +def eval2000_normalizer(text: str) -> str: + # print("TEXT original: ",text) + eform_count = text.count("contraction e_form") + # print("eform corunt:", eform_count) + if eform_count > 0: + text = eval2000_clean_eform(text, eform_count) + text = text.upper() + text = remove_languagetag(text) + text = replace_silphone(text) + text = remove_punctutation_and_other_symbol(text) + text = text.replace("IGNORE_TIME_SEGMENT_IN_SCORING", " ") + text = text.replace("IGNORE_TIME_SEGMENT_SCORING", " ") + spaces = re.findall(r"\s+", text) + for sp in spaces: + text = text.replace(sp, " ") + text = text.strip() + # text = self.whitespace_regexp.sub(" ", text).strip() + # print(text) + return text + + +def main(): + args = get_args() + sups = load_manifest_lazy_or_eager(args.input_sups) + assert isinstance(sups, SupervisionSet) + + tot, skip = 0, 0 + with SupervisionSet.open_writer(args.output_sups) as writer: + for sup in tqdm(sups, desc="Normalizing supervisions"): + tot += 1 + sup.text = eval2000_normalizer(sup.text) + if not sup.text: + skip += 1 + continue + writer.write(sup) + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/prepare_lang_bpe.py b/egs/swbd/ASR/local/prepare_lang_bpe.py new file mode 100755 index 000000000..d82a085ec --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang_bpe.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.utils import str2bool + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon( + model_file: str, words: List[str], oov: str +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) + + # Now convert word piece IDs back to word piece strings. 
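    # Editorial sketch with made-up piece IDs (not from any real bpe.model):
    # for words = ["HELLO", "WORLD"] the two-step round trip might look like
    #     sp.encode(["HELLO", "WORLD"], out_type=int)  ->  [[41, 7], [98]]
    #     sp.id_to_piece([41, 7])                      ->  ["▁HE", "LLO"]
    #     sp.id_to_piece([98])                         ->  ["▁WORLD"]
    # so every word is paired only with pieces that exist in the model,
    # which is what avoids OOV tokens here.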
+ words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = [ + "", + "!SIL", + "", + args.oov, + "#0", + "", + "", + ] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lm_training_data.py b/egs/swbd/ASR/local/prepare_lm_training_data.py new file mode 120000 index 000000000..abc00d421 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/rt03_data_prep.sh b/egs/swbd/ASR/local/rt03_data_prep.sh new file mode 100755 index 000000000..8a5f64324 --- /dev/null +++ b/egs/swbd/ASR/local/rt03_data_prep.sh @@ -0,0 +1,107 @@ +#!/usr/bin/env bash + +# RT-03 data preparation (conversational telephone speech part only) +# Adapted from Arnab Ghoshal's script for Hub-5 Eval 2000 by Peng Qi + +# To be run from one directory above this 
script. + +# Expects the standard directory layout for RT-03 + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/corpora/LDC/LDC2007S10" + echo "See comments in the script for more details" + exit 1 +fi + +sdir=$1 +[ ! -d $sdir/data/audio/eval03/english/cts ] && + echo Expecting directory $sdir/data/audio/eval03/english/cts to be present && exit 1 +[ ! -d $sdir/data/references/eval03/english/cts ] && + echo Expecting directory $tdir/data/references/eval03/english/cts to be present && exit 1 + +dir=data/local/rt03 +mkdir -p $dir + +rtroot=$sdir +tdir=$sdir/data/references/eval03/english/cts +sdir=$sdir/data/audio/eval03/english/cts + +find -L $sdir -iname '*.sph' | sort >$dir/sph.flist +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp + +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 + +awk -v sph2pipe=$sph2pipe '{ + printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); + printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 +#side A - channel 1, side B - channel 2 + +# Get segments file... +# segments file format is: utt-id side-id start-time end-time, e.g.: +# sw02001-A_000098-001156 sw02001-A 0.98 11.56 +#pem=$sdir/english/hub5e_00.pem +#[ ! -f $pem ] && echo "No such file $pem" && exit 1; +# pem file has lines like: +# en_4156 A unknown_speaker 301.85 302.48 + +#grep -v ';;' $pem \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + print utt,spk,$4,$5;}' | + sort -u >$dir/segments + +# stm file has lines like: +# en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER +# TODO(arnab): We should really be lowercasing this since the Edinburgh +# recipe uses lowercase. This is not used in the actual scoring. +#grep -v ';;' $tdir/reference/hub5e00.english.000405.stm \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' | + sort >$dir/text.all + +# We'll use the stm file for sclite scoring. There seem to be various errors +# in the stm file that upset hubscr.pl, and we fix them here. +cat $tdir/*.stm | + sed -e 's:((:(:' -e 's:::g' -e 's:::g' | + grep -v inter_segment_gap | + awk '{ + printf $1; if ($1==";;") printf(" %s",$2); else printf(($2==1)?" A":" B"); for(n=3;n<=NF;n++) printf(" %s", $n); print ""; }' \ + >$dir/stm +#$tdir/reference/hub5e00.english.000405.stm > $dir/stm +cp $rtroot/data/trans_rules/en20030506.glm $dir/glm + +# next line uses command substitution +# Just checking that the segments are the same in pem vs. stm. +! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && + echo "Segments from pem file and stm file do not match." && exit 1 + +grep -v IGNORE_TIME_SEGMENT_ $dir/text.all >$dir/text + +# create an utt2spk file that assumes each conversation side is +# a separate speaker. 
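# Editorial illustration (hypothetical IDs, not part of the original script):
# given a segments line such as
#   sw12345-A_000098-001156 sw12345-A 0.98 11.56
# the awk command below keeps only the first two fields, producing
#   sw12345-A_000098-001156 sw12345-A
# i.e. the conversation side "sw12345-A" is treated as the speaker id.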
+awk '{print $1,$2;}' $dir/segments >$dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk >$dir/spk2utt + +# cp $dir/segments $dir/segments.tmp +# awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ +# $dir/segments.tmp > $dir/segments + +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + >$dir/reco2file_and_channel || exit 1 + +./utils/fix_data_dir.sh $dir + +echo Data preparation and formatting completed for RT-03 +echo "(but not MFCC extraction)" diff --git a/egs/swbd/ASR/local/sort_lm_training_data.py b/egs/swbd/ASR/local/sort_lm_training_data.py new file mode 100755 index 000000000..bed3856e4 --- /dev/null +++ b/egs/swbd/ASR/local/sort_lm_training_data.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input the filename of LM training data +generated by ./local/prepare_lm_training_data.py and sorts +it by sentence length. + +Sentence length equals to the number of BPE tokens in a sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--in-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/lm_data.pt", + ) + + parser.add_argument( + "--out-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", + ) + + parser.add_argument( + "--out-statistics", + type=str, + help="Statistics about LM training data., data/bpe_500/statistics.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + in_lm_data = Path(args.in_lm_data) + out_lm_data = Path(args.out_lm_data) + assert in_lm_data.is_file(), f"{in_lm_data}" + if out_lm_data.is_file(): + logging.warning(f"{out_lm_data} exists - skipping") + return + data = torch.load(in_lm_data) + words2bpe = data["words"] + sentences = data["sentences"] + sentence_lengths = data["sentence_lengths"] + + num_sentences = sentences.dim0 + assert num_sentences == sentence_lengths.numel(), ( + num_sentences, + sentence_lengths.numel(), + ) + + indices = torch.argsort(sentence_lengths, descending=True) + + sorted_sentences = sentences[indices.to(torch.int32)] + sorted_sentence_lengths = sentence_lengths[indices] + + # Check that sentences are ordered by length + assert num_sentences == sorted_sentences.dim0, ( + num_sentences, + sorted_sentences.dim0, + ) + + cur = None + for i in range(num_sentences): + word_ids = sorted_sentences[i] + token_ids = words2bpe[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + if cur is not None: + assert cur >= token_ids.numel(), (cur, token_ids.numel()) + + cur = token_ids.numel() + assert cur == sorted_sentence_lengths[i] + + data["sentences"] = sorted_sentences + 
data["sentence_lengths"] = sorted_sentence_lengths + torch.save(data, args.out_lm_data) + logging.info(f"Saved to {args.out_lm_data}") + + statistics = Path(args.out_statistics) + + # Write statistics + num_words = sorted_sentences.numel() + num_tokens = sentence_lengths.sum().item() + max_sentence_length = sentence_lengths[indices[0]] + min_sentence_length = sentence_lengths[indices[-1]] + + step = 10 + hist, bins = np.histogram( + sentence_lengths.numpy(), + bins=np.arange(1, max_sentence_length + step, step), + ) + + histogram = np.stack((bins[:-1], hist)).transpose() + + with open(statistics, "w") as f: + f.write(f"num_sentences: {num_sentences}\n") + f.write(f"num_words: {num_words}\n") + f.write(f"num_tokens: {num_tokens}\n") + f.write(f"max_sentence_length: {max_sentence_length}\n") + f.write(f"min_sentence_length: {min_sentence_length}\n") + f.write("histogram:\n") + f.write(" bin count percent\n") + for row in histogram: + f.write( + f"{int(row[0]):>5} {int(row[1]):>5} " + f"{100.*row[1]/num_sentences:.3f}%\n" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/local/swbd1_data_prep.sh b/egs/swbd/ASR/local/swbd1_data_prep.sh new file mode 100755 index 000000000..159359491 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_data_prep.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash + +# Switchboard-1 training data preparation customized for Edinburgh +# Author: Arnab Ghoshal (Jan 2013) + +# To be run from one directory above this script. + +## The input is some directory containing the switchboard-1 release 2 +## corpus (LDC97S62). Note: we don't make many assumptions about how +## you unpacked this. We are just doing a "find" command to locate +## the .sph files. + +## The second input is optional, which should point to a directory containing +## Switchboard transcriptions/documentations (specifically, the conv.tab file). +## If specified, the script will try to use the actual speaker PINs provided +## with the corpus instead of the conversation side ID (Kaldi default). We +## will be using "find" to locate this file so we don't make any assumptions +## on the directory structure. (Peng Qi, Aug 2014) + +#check existing directories +if [ $# != 1 -a $# != 2 ]; then + echo "Usage: swbd1_data_prep.sh /path/to/SWBD [/path/to/SWBD_DOC]" + exit 1 +fi + +SWBD_DIR=$1 + +dir=data/local/train +mkdir -p $dir + +# Audio data directory check +if [ ! -d $SWBD_DIR ]; then + echo "Error: run.sh requires a directory argument" + exit 1 +fi + +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 + +# Option A: SWBD dictionary file check +[ ! -f ./swb_ms98_transcriptions/sw-ms98-dict.text ] && + echo "SWBD dictionary file does not exist" && exit 1 + +# find sph audio files +find -L $SWBD_DIR -iname '*.sph' | sort >$dir/sph.flist + +n=$(cat $dir/sph.flist | wc -l) +[ $n -ne 2435 ] && [ $n -ne 2438 ] && + echo Warning: expected 2435 or 2438 data data files, found $n + +# (1a) Transcriptions preparation +# make basic transcription file (add segments info) +# **NOTE: In the default Kaldi recipe, everything is made uppercase, while we +# make everything lowercase here. This is because we will be using SRILM which +# can optionally make everything lowercase (but not uppercase) when mapping +# LM vocabs. 
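# Editorial illustration (hypothetical utterance, not from the corpus): the awk
# command below rewrites an MS98 transcript line such as
#   sw2001A-ms98-a-0001 0.98 11.56 hi um yeah
# as
#   sw02001-A_000098-001156 hi um yeah
# i.e. "<recording>-<side>_<start>-<end>" (times in centiseconds) followed by
# the words.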
+awk '{ +name=substr($1,1,6); gsub("^sw","sw0",name); side=substr($1,7,1); +stime=$2; etime=$3; +printf("%s-%s_%06.0f-%06.0f", +name, side, int(100*stime+0.5), int(100*etime+0.5)); +for(i=4;i<=NF;i++) printf(" %s", $i); printf "\n" +}' ./swb_ms98_transcriptions/*/*/*-trans.text >$dir/transcripts1.txt + +# test if trans. file is sorted +export LC_ALL=C +sort -c $dir/transcripts1.txt || exit 1 # check it's sorted. + +# Remove SILENCE, and . + +# Note: we have [NOISE], [VOCALIZED-NOISE], [LAUGHTER], [SILENCE]. +# removing [SILENCE], and the and markers that mark +# speech to somone; we will give phones to the other three (NSN, SPN, LAU). +# There will also be a silence phone, SIL. +# **NOTE: modified the pattern matches to make them case insensitive +cat $dir/transcripts1.txt | + perl -ane 's:\s\[SILENCE\](\s|$):$1:gi; + s///gi; + s///gi; + print;' | + awk '{if(NF > 1) { print; } } ' >$dir/transcripts2.txt + +# **NOTE: swbd1_map_words.pl has been modified to make the pattern matches +# case insensitive +local/swbd1_map_words.pl -f 2- $dir/transcripts2.txt >$dir/text + +# format acronyms in text +python3 local/map_acronyms_transcripts.py -i $dir/text -o $dir/text_map \ + -M data/local/dict_nosp/acronyms.map +mv $dir/text_map $dir/text + +# (1c) Make segment files from transcript +#segments file format is: utt-id side-id start-time end-time, e.g.: +#sw02001-A_000098-001156 sw02001-A 0.98 11.56 +awk '{ +segment=$1; +split(segment,S,"[_-]"); +side=S[2]; audioname=S[1]; startf=S[3]; endf=S[4]; +print segment " " audioname "-" side " " startf/100 " " endf/100 +}' <$dir/text >$dir/segments + +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp + +awk -v sph2pipe=$sph2pipe '{ +printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); +printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 +#side A - channel 1, side B - channel 2 + +# this file reco2file_and_channel maps recording-id (e.g. sw02001-A) +# to the file name sw02001 and the A, e.g. +# sw02001-A sw02001 A +# In this case it's trivial, but in other corpora the information might +# be less obvious. Later it will be needed for ctm scoring. +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + >$dir/reco2file_and_channel || exit 1 + +awk '{spk=substr($1,1,9); print $1 " " spk}' $dir/segments >$dir/utt2spk || + exit 1 +sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl >$dir/spk2utt || exit 1 + +echo Switchboard-1 data preparation succeeded. + +utils/fix_data_dir.sh data/local/train diff --git a/egs/swbd/ASR/local/swbd1_map_words.pl b/egs/swbd/ASR/local/swbd1_map_words.pl new file mode 100755 index 000000000..4fb8d4ffe --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_map_words.pl @@ -0,0 +1,52 @@ +#!/usr/bin/env perl + +# Modified from swbd_map_words.pl in Kaldi s5 recipe to make pattern +# matches case-insensitive --Arnab (Jan 2013) + +if ($ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. 
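      # Editorial note (not in the original script): with "-f 2-", as used by
      # swbd1_data_prep.sh, $1 is "2" and $2 is empty, so $field_begin becomes
      # 1 while $field_end stays undefined; every field from the second one
      # onwards is then mapped and the utterance id in field 1 is left as is.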
+ } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } +} + + +while (<>) { + @A = split(" ", $_); + for ($n = 0; $n < @A; $n++) { + $a = $A[$n]; + if ( (!defined $field_begin || $n >= $field_begin) + && (!defined $field_end || $n <= $field_end)) { + # e.g. [LAUGHTER-STORY] -> STORY; + $a =~ s:(|\-)^\[LAUGHTER-(.+)\](|\-)$:$1$2$3:i; + # $1 and $3 relate to preserving trailing "-" + $a =~ s:^\[(.+)/.+\](|\-)$:$1$2:; # e.g. [IT'N/ISN'T] -> IT'N ... note, + # 1st part may include partial-word stuff, which we process further below, + # e.g. [LEM[GUINI]-/LINGUINI] + # the (|\_) at the end is to accept and preserve trailing -'s. + $a =~ s:^(|\-)\[[^][]+\](.+)$:-$2:; # e.g. -[AN]Y , note \047 is quote; + # let the leading - be optional on input, as sometimes omitted. + $a =~ s:^(.+)\[[^][]+\](|\-)$:$1-:; # e.g. AB[SOLUTE]- -> AB-; + # let the trailing - be optional on input, as sometimes omitted. + $a =~ s:([^][]+)\[.+\]$:$1:; # e.g. EX[SPECIALLY]-/ESPECIALLY] -> EX- + # which is a mistake in the input. + $a =~ s:^\{(.+)\}$:$1:; # e.g. {YUPPIEDOM} -> YUPPIEDOM + $a =~ s:[A-Z]\[([^][])+\][A-Z]:$1-$3:i; # e.g. AMMU[N]IT- -> AMMU-IT- + $a =~ s:_\d$::; # e.g. THEM_1 -> THEM + } + $A[$n] = $a; + } + print join(" ", @A) . "\n"; +} diff --git a/egs/swbd/ASR/local/swbd1_prepare_dict.sh b/egs/swbd/ASR/local/swbd1_prepare_dict.sh new file mode 100755 index 000000000..eff5fb5f1 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_prepare_dict.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash + +# Formatting the Mississippi State dictionary for use in Edinburgh. Differs +# from the one in Kaldi s5 recipe in that it uses lower-case --Arnab (Jan 2013) + +# To be run from one directory above this script. + +#check existing directories +[ $# != 0 ] && echo "Usage: local/swbd1_data_prep.sh" && exit 1 + +srcdir=. # This is where we downloaded some stuff.. +dir=./data/local/dict_nosp +mkdir -p $dir +srcdict=$srcdir/swb_ms98_transcriptions/sw-ms98-dict.text + +# assume swbd_p1_data_prep.sh was done already. +[ ! -f "$srcdict" ] && echo "$0: No such file $srcdict" && exit 1 + +cp $srcdict $dir/lexicon0.txt || exit 1 +chmod a+w $dir/lexicon0.txt +patch 0' | sort >$dir/lexicon1.txt || exit 1 + +cat $dir/lexicon1.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}' | + grep -v sil >$dir/nonsilence_phones.txt || exit 1 + +( + echo sil + echo spn + echo nsn + echo lau +) >$dir/silence_phones.txt + +echo sil >$dir/optional_silence.txt + +# No "extra questions" in the input to this setup, as we don't +# have stress or tone. +echo -n >$dir/extra_questions.txt + +cp local/MSU_single_letter.txt $dir/ +# Add to the lexicon the silences, noises etc. +# Add single letter lexicon +# The original swbd lexicon does not have precise single letter lexicion +# e.g. it does not have entry of W +( + echo '!SIL SIL' + echo '[VOCALIZED-NOISE] spn' + echo '[NOISE] nsn' + echo '[LAUGHTER] lau' + echo ' spn' +) | + cat - $dir/lexicon1.txt $dir/MSU_single_letter.txt >$dir/lexicon2.txt || exit 1 + +# Map the words in the lexicon. That is-- for each word in the lexicon, we map it +# to a new written form. The transformations we do are: +# remove laughter markings, e.g. +# [LAUGHTER-STORY] -> STORY +# Remove partial-words, e.g. +# -[40]1K W AH N K EY +# becomes -1K +# and +# -[AN]Y IY +# becomes +# -Y +# -[A]B[OUT]- B +# becomes +# -B- +# Also, curly braces, which appear to be used for "nonstandard" +# words or non-words, are removed, e.g. 
+# {WOLMANIZED} W OW L M AX N AY Z D +# -> WOLMANIZED +# Also, mispronounced words, e.g. +# [YEAM/YEAH] Y AE M +# are changed to just e.g. YEAM, i.e. the orthography +# of the mispronounced version. +# Note-- this is only really to be used in training. The main practical +# reason is to avoid having tons of disambiguation symbols, which +# we otherwise would get because there are many partial words with +# the same phone sequences (most problematic: S). +# Also, map +# THEM_1 EH M -> THEM +# so that multiple pronunciations just have alternate entries +# in the lexicon. + +local/swbd1_map_words.pl -f 1 $dir/lexicon2.txt | sort -u \ + >$dir/lexicon3.txt || exit 1 + +python3 local/format_acronyms_dict.py -i $dir/lexicon3.txt -o $dir/lexicon4.txt \ + -L $dir/MSU_single_letter.txt -M $dir/acronyms_raw.map +cat $dir/acronyms_raw.map | sort -u >$dir/acronyms.map + +(echo 'i ay') | cat - $dir/lexicon4.txt | tr '[A-Z]' '[a-z]' | sort -u >$dir/lexicon5.txt + +pushd $dir >&/dev/null +ln -sf lexicon5.txt lexicon.txt # This is the final lexicon. +popd >&/dev/null +rm $dir/lexiconp.txt 2>/dev/null +echo Prepared input dictionary and phone-sets for Switchboard phase 1. diff --git a/egs/swbd/ASR/local/train_bpe_model.py b/egs/swbd/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..9b4e28635 --- /dev/null +++ b/egs/swbd/ASR/local/train_bpe_model.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. 
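    # Editorial note (an assumption from the settings above, not stated in the
    # original script): the trained model is expected to assign
    #   0 -> <blk>, 1 -> <sos/eos>, 2 -> <unk>
    # with the SWBD special tokens appended below taking the next ids, ahead
    # of the regular BPE pieces.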
+ + user_defined_symbols += ["[LAUGHTER]", "[NOISE]", "[VOCALIZED-NOISE]"] + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/validate_bpe_lexicon.py b/egs/swbd/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/swbd/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh new file mode 100755 index 000000000..47d12613b --- /dev/null +++ b/egs/swbd/ASR/prepare.sh @@ -0,0 +1,463 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. Most of them can't be downloaded automatically +# as they are not publically available and require a license purchased +# from the LDC. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=./download +# swbd1_dir="/export/corpora3/LDC/LDC97S62" +swbd1_dir=./download/LDC97S62/ + +# eval2000_dir contains the following files and directories +# downloaded from LDC website: +# - LDC2002S09 +# - hub5e_00 +# - LDC2002T43 +# - reference +eval2000_dir="/export/corpora2/LDC/eval2000" + +rt03_dir="/export/corpora/LDC/LDC2007S10" +fisher_dir="/export/corpora3/LDC/LDC2004T19" + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "swbd1_dir: $swbd1_dir" +log "eval2000_dir: $eval2000_dir" +log "rt03_dir: $rt03_dir" + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare SwitchBoard manifest" + # We assume that you have downloaded the SwitchBoard corpus + # to respective dirs + mkdir -p data/manifests + if [ ! 
-e data/manifests/.swbd.done ]; then + lhotse prepare switchboard --absolute-paths 1 --omit-silence $swbd1_dir data/manifests/swbd + ./local/normalize_and_filter_supervisions.py \ + data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all.jsonl.gz data/manifests/swbd/swbd_supervisions_orig.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz data/manifests/swbd/swbd_supervisions_all.jsonl.gz + + lhotse cut simple \ + -r data/manifests/swbd/swbd_recordings_all.jsonl.gz \ + -s data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all.jsonl.gz + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/swbd/swbd_train_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz + + num_splits=16 + mkdir -p data/manifests/swbd_split${num_splits} + lhotse split ${num_splits} \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz \ + data/manifests/swbd_split${num_splits} + + lhotse prepare eval2000 --absolute-paths 1 $eval2000_dir data/manifests/eval2000 + ./local/normalize_eval2000.py \ + data/manifests/eval2000/eval2000_supervisions_unnorm.jsonl.gz \ + data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz + + lhotse cut simple \ + -r data/manifests/eval2000/eval2000_recordings_all.jsonl.gz \ + -s data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz + + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz + + sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ + $eval2000_dir/LDC2002T43/reference/hub5e00.english.000405.stm > data/manifests/eval2000/stm + cp $eval2000_dir/LDC2002T43/reference/en20000405_hub5.glm $dir/glm + + # ./local/rt03_data_prep.sh $rt03_dir + + # normalize eval2000 and rt03 texts by + # 1) convert upper to lower + # 2) remove tags (%AH) (%HESITATION) (%UH) + # 3) remove + # 4) remove "(" or ")" + # for x in rt03; do + # cp data/local/${x}/text data/local/${x}/text.org + # paste -d "" \ + # <(cut -f 1 -d" " data/local/${x}/text.org) \ + # <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") | + # sed -e 's/\s\+/ /g' >data/local/${x}/text + # rm data/local/${x}/text.org + # done + + # lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests + + touch data/manifests/.swbd.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3 I: Compute fbank for SwitchBoard" + if [ ! 
-e data/fbank/.swbd.done ]; then + mkdir -p data/fbank/swbd_split${num_splits}/ + for index in $(seq 1 16); do + ./local/compute_fbank_swbd.py --split-index ${index} & + done + wait + pieces=$(find data/fbank/swbd_split${num_splits} -name "swbd_cuts_all.*.jsonl.gz") + lhotse combine $pieces data/fbank/swbd_cuts_all.jsonl.gz + touch data/fbank/.swbd.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3 II: Compute fbank for eval2000" + if [ ! -e data/fbank/.eval2000.done ]; then + mkdir -p data/fbank/eval2000/ + ./local/compute_fbank_eval2000.py + touch data/fbank/.eval2000.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + if ! which jq; then + echo "This script is intended to be used with jq but you have not installed jq + Note: in Linux, you can install jq with the following command: + 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + 2. chmod +x ./jq + 3. cp jq /usr/bin" && exit 1 + fi + if [ ! -f $lang_dir/text ] || [ ! -s $lang_dir/text ]; then + log "Prepare text." + gunzip -c data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $lang_dir/text + fi + + log "Prepare dict" + ./local/swbd1_prepare_dict.sh + cut -f 2- -d" " $lang_dir/text >${lang_dir}/input.txt + # [noise] nsn + # !sil sil + # spn + cat data/local/dict_nosp/lexicon.txt | sed 's/-//g' | sed 's/\[vocalizednoise\]/\[vocalized-noise\]/g' | + sort | uniq >$lang_dir/lexicon_lower.txt + + cat $lang_dir/lexicon_lower.txt | tr a-z A-Z > $lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + + cat data/lang_phone/text | cut -d " " -f 2- >$lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! 
-f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare bigram token-level P for MMI training" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + >$lang_dir/transcript_tokens.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa >$lang_dir/P.fst.txt + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare G" + lang_dir=data/lang_phone + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/3-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.arpa >data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + ./shared/make_kn_lm.py \ + -ngram-order 4 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/4-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + data/lm/4-gram.arpa >data/lm/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Generate LM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! 
-f $out_dir/train.txt ]; then + tail -n 250000 data/lang_phone/input.txt > $out_dir/train.txt + fi + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data data/lang_phone/input.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Generate LM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/valid.txt ]; then + head -n 14332 data/lang_phone/input.txt > $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Generate LM test data" + testsets=(eval2000) + + for testset in ${testsets[@]}; do + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/${testset}.txt ]; then + gunzip -c data/manifests/${testset}/eval2000_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $out_dir/${testset}.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/${testset}.txt \ + --lm-archive $out_dir/lm_data-${testset}.pt + done + done +fi + +if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then + log "Stage 14: Sort LM training data" + testsets=(eval2000) + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + for testset in ${testsets[@]}; do + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-${testset}.pt \ + --out-lm-data $out_dir/sorted_lm_data-${testset}.pt \ + --out-statistics $out_dir/statistics-test-${testset}.txt + done + done +fi diff --git a/egs/swbd/ASR/shared b/egs/swbd/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/swbd/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/swbd/ASR/utils/filter_scp.pl b/egs/swbd/ASR/utils/filter_scp.pl new file mode 100755 index 000000000..b76d37f41 --- /dev/null +++ b/egs/swbd/ASR/utils/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+ + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/swbd/ASR/utils/fix_data_dir.sh b/egs/swbd/ASR/utils/fix_data_dir.sh new file mode 100755 index 000000000..ca0972ca8 --- /dev/null +++ b/egs/swbd/ASR/utils/fix_data_dir.sh @@ -0,0 +1,197 @@ +#!/bin/bash + +# This script makes sure that only the segments present in +# all of "feats.scp", "wav.scp" [if present], segments [if present] +# text, and utt2spk are present in any of them. +# It puts the original contents of data-dir into +# data-dir/.backup + +cmd="$@" + +utt_extra_files= +spk_extra_files= + +. utils/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: utils/data/fix_data_dir.sh " + echo "e.g.: utils/data/fix_data_dir.sh data/train" + echo "This script helps ensure that the various files in a data directory" + echo "are correctly sorted and filtered, for example removing utterances" + echo "that have no features (if feats.scp is present)" + exit 1 +fi + +data=$1 + +if [ -f $data/images.scp ]; then + image/fix_data_dir.sh $cmd + exit $? +fi + +mkdir -p $data/.backup + +[ ! -d $data ] && echo "$0: no such directory $data" && exit 1; + +[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1; + +set -e -o pipefail -u + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted { + file=$1 + sort -k1,1 -u <$file >$file.tmp + if ! 
cmp -s $file $file.tmp; then + echo "$0: file $1 is not in sorted order or not unique, sorting it" + mv $file.tmp $file + else + rm $file.tmp + fi +} + +for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \ + reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + check_sorted $data/$x + fi +done + + +function filter_file { + filter=$1 + file_to_filter=$2 + cp $file_to_filter ${file_to_filter}.tmp + utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter + if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then + length1=$(cat ${file_to_filter}.tmp | wc -l) + length2=$(cat ${file_to_filter} | wc -l) + if [ $length1 -ne $length2 ]; then + echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter." + fi + fi + rm $file_to_filter.tmp +} + +function filter_recordings { + # We call this once before the stage when we filter on utterance-id, and once + # after. + + if [ -f $data/segments ]; then + # We have a segments file -> we need to filter this and the file wav.scp, and + # reco2file_and_utt, if it exists, to make sure they have the same list of + # recording-ids. + + if [ ! -f $data/wav.scp ]; then + echo "$0: $data/segments exists but not $data/wav.scp" + exit 1; + fi + awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings + n1=$(cat $tmpdir/recordings | wc -l) + [ ! -s $tmpdir/recordings ] && \ + echo "Empty list of recordings (bad file $data/segments)?" && exit 1; + utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp + mv $tmpdir/recordings.tmp $tmpdir/recordings + + + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + filter_file $tmpdir/recordings $data/segments + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + rm $data/segments.tmp + + filter_file $tmpdir/recordings $data/wav.scp + [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel + [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur + true + fi +} + +function filter_speakers { + # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... + utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + for s in cmvn.scp spk2gender; do + f=$data/$s + if [ -f $f ]; then + filter_file $f $tmpdir/speakers + fi + done + + filter_file $tmpdir/speakers $data/spk2utt + utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk + + for s in cmvn.scp spk2gender $spk_extra_files; do + f=$data/$s + if [ -f $f ]; then + filter_file $tmpdir/speakers $f + fi + done +} + +function filter_utts { + cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts + + ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order (fix this yourself)" && exit 1; + + ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order when sorted first on speaker-id " && \ + echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; + + ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \ + echo "spk2utt is not in sorted order (fix this yourself)" && exit 1; + + if [ -f $data/utt2uniq ]; then + ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \ + echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1; + fi + + maybe_wav= + maybe_reco2dur= + [ ! 
-f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist. + [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts + for x in feats.scp text segments utt2lang $maybe_wav; do + if [ -f $data/$x ]; then + utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp + mv $tmpdir/utts.tmp $tmpdir/utts + fi + done + [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \ + rm $tmpdir/utts && exit 1; + + + if [ -f $data/utt2spk ]; then + new_nutts=$(cat $tmpdir/utts | wc -l) + old_nutts=$(cat $data/utt2spk | wc -l) + if [ $new_nutts -ne $old_nutts ]; then + echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts" + else + echo "fix_data_dir.sh: kept all $old_nutts utterances." + fi + fi + + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then + utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x + fi + fi + done + +} + +filter_recordings +filter_speakers +filter_utts +filter_speakers +filter_recordings + +utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + +echo "fix_data_dir.sh: old files are kept in $data/.backup" diff --git a/egs/swbd/ASR/utils/parse_options.sh b/egs/swbd/ASR/utils/parse_options.sh new file mode 100755 index 000000000..34476fdb3 --- /dev/null +++ b/egs/swbd/ASR/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### No we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl new file mode 100755 index 000000000..23992f25d --- /dev/null +++ b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl @@ -0,0 +1,27 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +while(<>){ + @A = split(" ", $_); + @A > 1 || die "Invalid line in spk2utt file: $_"; + $s = shift @A; + foreach $u ( @A ) { + print "$u $s\n"; + } +} + + diff --git a/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl new file mode 100755 index 000000000..6e0e438ca --- /dev/null +++ b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. 
+# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# converts an utt2spk file to a spk2utt file. +# Takes input from the stdin or from a file argument; +# output goes to the standard out. + +if ( @ARGV > 1 ) { + die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; +} + +while(<>){ + @A = split(" ", $_); + @A == 2 || die "Invalid line in utt2spk file: $_"; + ($u,$s) = @A; + if(!$seen_spk{$s}) { + $seen_spk{$s} = 1; + push @spklist, $s; + } + push (@{$spk_hash{$s}}, "$u"); +} +foreach $s (@spklist) { + $l = join(' ',@{$spk_hash{$s}}); + print "$s $l\n"; +} diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index 417515968..d03970265 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -915,7 +915,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py index d80e0147c..aee3972cd 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -69,7 +69,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer -from icefall import diagnostics, byte_encode, tokenize_by_CJK_char +from icefall import byte_encode, diagnostics, tokenize_by_CJK_char from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -1018,7 +1018,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/conformer_ctc2/local b/egs/tedlium3/ASR/conformer_ctc2/local new file mode 120000 index 000000000..c820590c5 --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/local @@ -0,0 +1 @@ +../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py index 42e4c010a..fc3e3b2d9 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/train.py +++ b/egs/tedlium3/ASR/conformer_ctc2/train.py @@ -905,7 +905,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/local b/egs/tedlium3/ASR/pruned_transducer_stateless/local new file mode 120000 index 000000000..c820590c5 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/local @@ -0,0 +1 @@ +../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/local b/egs/tedlium3/ASR/transducer_stateless/local new file mode 120000 index 000000000..c820590c5 --- /dev/null +++ b/egs/tedlium3/ASR/transducer_stateless/local @@ -0,0 +1 @@ +../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/local b/egs/tedlium3/ASR/zipformer/local new file mode 120000 index 000000000..c820590c5 
--- /dev/null +++ b/egs/tedlium3/ASR/zipformer/local @@ -0,0 +1 @@ +../local \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 9271c8438..33d03908c 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -1126,7 +1126,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py index 760fad974..140b1d37f 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py @@ -315,6 +315,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index e703100a9..82bc882bd 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -886,7 +886,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py index a46ff5a07..2d46eede1 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py @@ -258,6 +258,7 @@ def main(): encoder_session = ort.InferenceSession( args.onnx_encoder_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) test_encoder(model, encoder_session) @@ -265,6 +266,7 @@ def main(): decoder_session = ort.InferenceSession( args.onnx_decoder_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) test_decoder(model, decoder_session) @@ -272,14 +274,17 @@ def main(): joiner_session = ort.InferenceSession( args.onnx_joiner_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) joiner_encoder_proj_session = ort.InferenceSession( args.onnx_joiner_encoder_proj_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) joiner_decoder_proj_session = ort.InferenceSession( args.onnx_joiner_decoder_proj_filename, sess_options=options, + providers=["CPUExecutionProvider"], ) test_joiner( model, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index 48b347b64..49977e01b 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -851,7 +851,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 
9a926d7e5..921766ad4 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -404,6 +404,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py index 68c7cc352..037c7adf1 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -335,6 +335,7 @@ def export_decoder_model_onnx( vocab_size = decoder_model.decoder.vocab_size y = torch.zeros(10, context_size, dtype=torch.int64) + decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, y, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py index facfc2258..c31db6859 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py @@ -139,6 +139,7 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) self.init_encoder_states() @@ -186,6 +187,7 @@ class OnnxModel: self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -199,6 +201,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py index e7c8b4556..c784853ee 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py @@ -158,12 +158,14 @@ class OnnxModel: self.encoder = ort.InferenceSession( encoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) def init_decoder(self, decoder_model_filename: str): self.decoder = ort.InferenceSession( decoder_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map @@ -177,6 +179,7 @@ class OnnxModel: self.joiner = ort.InferenceSession( joiner_model_filename, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) joiner_meta = self.joiner.get_modelmeta().custom_metadata_map diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 8e1b12dba..931e699d9 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -985,7 +985,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git 
a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index 83dbfa22f..b1557dedb 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -1128,7 +1128,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py index 5b5ac17be..a6fa46b17 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py @@ -1001,7 +1001,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py index f8dd7b287..8c53972fd 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -993,7 +993,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/yesno/ASR/local/prepare_lang_fst.py b/egs/yesno/ASR/local/prepare_lang_fst.py new file mode 120000 index 000000000..c5787c534 --- /dev/null +++ b/egs/yesno/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh index d4ef8d601..41db0cf7c 100755 --- a/egs/yesno/ASR/prepare.sh +++ b/egs/yesno/ASR/prepare.sh @@ -60,6 +60,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then ) > $lang_dir/lexicon.txt ./local/prepare_lang.py + ./local/prepare_lang_fst.py --lang-dir ./data/lang_phone --has-silence 1 fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py index 84390fca5..7581ecb83 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -156,7 +156,6 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - # Note: We don't use key padding mask for attention during decoding nnet_output = model(features) batch_size = nnet_output.shape[0] diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py new file mode 100755 index 000000000..ff8c742af --- /dev/null +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with H +on CPU using OpenFST and decoders from kaldi. 
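+
+Here ``H`` is the CTC topology FST generated by ./local/prepare_lang_fst.py
+(see ./prepare.sh). Since no lexicon is composed in, the decoder produces
+token sequences, which are mapped back to symbols via --tokens.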
+ +Usage: + + ./tdnn/jit_pretrained_decode_with_H.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --H ./data/lang_phone/H.fst \ + --tokens ./data/lang_phone/tokens.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +Note that to generate ./tdnn/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument("--H", type=str, required=True, help="Path to H.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_tokens(tokens_txt: str) -> Dict[int, str]: + id2token = dict() + with open(tokens_txt, encoding="utf-8") as f: + for line in f: + token, idx = line.strip().split() + id2token[int(idx)] = token + + return id2token + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
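+      Note:
+        Files whose sample rate differs from ``expected_sample_rate`` are
+        resampled on the fly, and only the first channel of each file is used.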
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + decodable = DecodableCtc(nnet_output) + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + sample_rate = 8000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 23 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output = model(features) + + id2token = read_tokens(args.tokens) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[0], + nnet_output=nnet_output[i], + H=H, + id2token=id2token, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py new file mode 100755 index 000000000..05ba74f9a --- /dev/null +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with HL +on CPU using OpenFST and decoders from kaldi. 
+ +Usage: + + ./tdnn/jit_pretrained_decode_with_HL.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HL ./data/lang_phone/HL.fst \ + --words ./data/lang_phone/words.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +Note that to generate ./tdnn/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + decodable = DecodableCtc(nnet_output) + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + hyps = [id2word[i] for i in osymbols_out if id2word[i] != ""] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + sample_rate = 8000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 23 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output = model(features) + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[0], + nnet_output=nnet_output[i], + HL=HL, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py index b23a2a381..72a1d69c8 100755 --- a/egs/yesno/ASR/tdnn/onnx_pretrained.py +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -54,6 +54,7 @@ class OnnxModel: self.model = ort.InferenceSession( nn_model, sess_options=self.session_opts, + providers=["CPUExecutionProvider"], ) meta = self.model.get_modelmeta().custom_metadata_map diff --git a/icefall/__init__.py b/icefall/__init__.py index 05e2b408c..b1e4313e9 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -1,12 +1,6 @@ # isort:skip_file -from . import ( - checkpoint, - decode, - dist, - env, - utils -) +from . 
import checkpoint, decode, dist, env, utils from .byte_utils import ( byte_decode, diff --git a/icefall/context_graph.py b/icefall/context_graph.py index 01836df04..0b7c42c0b 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -227,7 +227,6 @@ class ContextGraph: filename: Optional[str] = "", symbol_table: Optional[Dict[int, str]] = None, ) -> "Digraph": # noqa - """Visualize a ContextGraph via graphviz. Render ContextGraph as an image via graphviz, and return the Digraph object; diff --git a/icefall/ctc/.gitignore b/icefall/ctc/.gitignore new file mode 100644 index 000000000..8154cb57f --- /dev/null +++ b/icefall/ctc/.gitignore @@ -0,0 +1,2 @@ +*.pdf +*.gv diff --git a/icefall/ctc/README.md b/icefall/ctc/README.md new file mode 100644 index 000000000..0096bc096 --- /dev/null +++ b/icefall/ctc/README.md @@ -0,0 +1,17 @@ +# Introduction + +This folder uses [kaldifst][kaldifst] for graph construction +and decoders from [kaldi-decoder][kaldi-decoder] for CTC decoding. + +It supports only `CPU`. + +You can use + +```bash +pip install kaldifst kaldi-decoder +``` +to install the dependencies. + +[kaldi-decoder]: https://github.com/i2-fsa/kaldi-decoder +[kaldifst]: https://github.com/k2-fsa/kaldifst +[k2]: https://github.com/k2-fsa/k2 diff --git a/icefall/ctc/__init__.py b/icefall/ctc/__init__.py new file mode 100644 index 000000000..b546b31af --- /dev/null +++ b/icefall/ctc/__init__.py @@ -0,0 +1,6 @@ +from .prepare_lang import ( + Lexicon, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) +from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo diff --git a/icefall/ctc/prepare_lang.py b/icefall/ctc/prepare_lang.py new file mode 100644 index 000000000..4801b1beb --- /dev/null +++ b/icefall/ctc/prepare_lang.py @@ -0,0 +1,334 @@ +# Copyright 2023 Xiaomi Corp. (author: Fangjun Kuang) + +""" +The lang_dir should contain the following files: + - "lexicon_disambig.txt" + - "tokens.txt" + - "words.txt" +""" + +import math +from collections import defaultdict +from pathlib import Path +from typing import List, Tuple + +import kaldifst +import re + + +class Lexicon: + """Once constructed it is immutable""" + + def __init__( + self, + lang_dir: str, + disambig_pattern: str = re.compile(r"^#\d+$"), + ): + """ + Args: + lang_dir: + The path to the lang directory. We expect that it contains the + following files: + - lexicon_disambig.txt + - tokens.txt + - words.txt + + The format of the above files is described below. + + (1) lexicon_disambig.txt + + Each line in the lexicon_disambig.txt has the following format: + + word token1 token2 ... tokenN + + That is, the first field is the word, the remaining fields are + pronunciations of this word. Fields are separated by space(s). + + (2) tokens.txt + + Each line in tokens.txt has two fields separated by space(s): + + token ID + + The first field is the token symbol and the second filed is the + integer ID of the token. + + (3) words.txt + + Each line in words.txt has two fields separated by space(s): + + word ID + + The first field is the word symbol and the second filed is the + integer ID of the word. + disambig_pattern: + It contains the pattern for disambiguation symbols. 
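+            The default pattern matches symbols of the form ``#0``, ``#1``,
+            ``#2``, and so on.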
+ """ + lang_dir = Path(lang_dir) + + lexicon_txt = lang_dir / "lexicon_disambig.txt" + tokens_txt = lang_dir / "tokens.txt" + words_txt = lang_dir / "words.txt" + + assert lexicon_txt.is_file(), lexicon_txt + assert tokens_txt.is_file(), tokens_txt + assert words_txt.is_file(), words_txt + + self._read_lexicon(lexicon_txt) + self._read_tokens(tokens_txt) + self._read_words(words_txt) + + self.disambig_pattern = disambig_pattern + + max_disambig_id = -1 + for s, i in self.token2id.items(): + if self.disambig_pattern.match(s) and i > max_disambig_id: + max_disambig_id = i + + self.max_disambig_id = max_disambig_id + + def _read_lexicon(self, lexicon_txt: str): + word2phones = defaultdict(list) + with open(lexicon_txt, encoding="utf-8") as f: + for line in f: + word_phones = line.strip().split() + assert len(word_phones) >= 2, (word_phones, line) + word = word_phones[0] + phones: str = " ".join(word_phones[1:]) + word2phones[word].append(phones) + # We use a list here since a word may have multiple + # pronunciations + + self.word2phones = word2phones + + def _read_tokens(self, tokens_txt): + token2id = dict() + id2token = dict() + with open(tokens_txt, encoding="utf-8") as f: + for line in f: + token_id = line.strip().split() + assert len(token_id) == 2, token_id + + token = token_id[0] + idx = int(token_id[1]) + + assert token not in token2id, f"Duplicate token {line}" + assert idx not in id2token, f"Duplicate ID {line}" + + token2id[token] = idx + id2token[idx] = token + self.token2id = token2id + self.id2token = id2token + + def _read_words(self, words_txt): + word2id = dict() + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word_id = line.strip().split() + assert len(word_id) == 2, word_id + + word = word_id[0] + idx = int(word_id[1]) + + assert word not in word2id, f"Duplicate token {line}" + assert idx not in id2word, f"Duplicate ID {line}" + + word2id[word] = idx + id2word[idx] = word + + self.word2id = word2id + self.id2word = id2word + + def __iter__(self) -> Tuple[str, List[str]]: + for word, phones_list in self.word2phones.items(): + for phones in phones_list: + yield word, phones + + def __str__(self): + return str(self.word2phones) + + @property + def tokens(self) -> List[int]: + """Return a list of token IDs excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. 
+ """ + ans = [] + for s in self.token2id: + if not self.disambig_pattern.match(s): + ans.append(self.token2id[s]) + if 0 in ans: + ans.remove(0) + ans.sort() + return ans + + +# See also +# http://vpanayotov.blogspot.com/2012/06/kaldi-decoding-graph-construction.html +def make_lexicon_fst_with_silence( + lexicon: Lexicon, + sil_prob: float = 0.5, + sil_phone: str = "SIL", + attach_symbol_table: bool = True, +) -> kaldifst.StdVectorFst: + phone2id = lexicon.token2id + word2id = lexicon.word2id + + assert sil_phone in phone2id + + assert sil_phone in phone2id, sil_phone + + sil_cost = -1 * math.log(sil_prob) + no_sil_cost = -1 * math.log(1.0 - sil_prob) + + fst = kaldifst.StdVectorFst() + + start_state = fst.add_state() + loop_state = fst.add_state() + sil_state = fst.add_state() + + fst.start = start_state + fst.set_final(state=loop_state, weight=0) + + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=0, + olabel=0, + weight=no_sil_cost, + nextstate=loop_state, + ), + ) + + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=0, + olabel=0, + weight=sil_cost, + nextstate=sil_state, + ), + ) + + fst.add_arc( + state=sil_state, + arc=kaldifst.StdArc( + ilabel=phone2id[sil_phone], + olabel=0, + weight=0, + nextstate=loop_state, + ), + ) + + for word, phones in lexicon: + phoneseq = phones.split() + pron_cost = 0 + cur_state = loop_state + + for i in range(len(phoneseq) - 1): + next_state = fst.add_state() + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]], + olabel=word2id[word] if i == 0 else 0, + weight=pron_cost if i == 0 else 0, + nextstate=next_state, + ), + ) + cur_state = next_state + + i = len(phoneseq) - 1 # note: i == -1 if phoneseq is empty. + + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]] if i >= 0 else 0, + olabel=word2id[word] if i <= 0 else 0, + weight=no_sil_cost + (pron_cost if i <= 0 else 0), + nextstate=loop_state, + ), + ) + + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]] if i >= 0 else 0, + olabel=word2id[word] if i <= 0 else 0, + weight=sil_cost + (pron_cost if i <= 0 else 0), + nextstate=sil_state, + ), + ) + + if attach_symbol_table: + isym = kaldifst.SymbolTable() + for p, i in phone2id.items(): + isym.add_symbol(symbol=p, key=i) + fst.input_symbols = isym + + osym = kaldifst.SymbolTable() + for w, i in word2id.items(): + osym.add_symbol(symbol=w, key=i) + fst.output_symbols = osym + + return fst + + +def make_lexicon_fst_no_silence( + lexicon: Lexicon, + attach_symbol_table: bool = True, +) -> kaldifst.StdVectorFst: + phone2id = lexicon.token2id + word2id = lexicon.word2id + + fst = kaldifst.StdVectorFst() + + start_state = fst.add_state() + fst.start = start_state + fst.set_final(state=start_state, weight=0) + + for word, phones in lexicon: + phoneseq = phones.split() + pron_cost = 0 + cur_state = start_state + + for i in range(len(phoneseq) - 1): + next_state = fst.add_state() + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]], + olabel=word2id[word] if i == 0 else 0, + weight=pron_cost if i == 0 else 0, + nextstate=next_state, + ), + ) + cur_state = next_state + + i = len(phoneseq) - 1 # note: i == -1 if phoneseq is empty. 
+ + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]] if i >= 0 else 0, + olabel=word2id[word] if i <= 0 else 0, + weight=pron_cost if i <= 0 else 0, + nextstate=start_state, + ), + ) + + if attach_symbol_table: + isym = kaldifst.SymbolTable() + for p, i in phone2id.items(): + isym.add_symbol(symbol=p, key=i) + fst.input_symbols = isym + + osym = kaldifst.SymbolTable() + for w, i in word2id.items(): + osym.add_symbol(symbol=w, key=i) + fst.output_symbols = osym + + return fst diff --git a/icefall/ctc/test_ctc_topo.py b/icefall/ctc/test_ctc_topo.py new file mode 100755 index 000000000..4d4667209 --- /dev/null +++ b/icefall/ctc/test_ctc_topo.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +from pathlib import Path + +import graphviz +import kaldifst +import sentencepiece as spm +from prepare_lang import ( + Lexicon, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) +from topo import add_disambig_self_loops, add_one, build_standard_ctc_topo + + +def test_yesno(): + lang_dir = "/Users/fangjun/open-source/icefall/egs/yesno/ASR/data/lang_phone" + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! Skip testing") + return + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + + H = build_standard_ctc_topo(max_token_id=max_token_id) + + isym = kaldifst.SymbolTable() + isym.add_symbol(symbol="", key=0) + for i in range(1, max_token_id + 1): + isym.add_symbol(symbol=lexicon.id2token[i], key=i) + + osym = kaldifst.SymbolTable() + osym.add_symbol(symbol="", key=0) + for i in range(1, max_token_id + 1): + osym.add_symbol(symbol=lexicon.id2token[i], key=i) + + H.input_symbols = isym + H.output_symbols = osym + + fst_dot = kaldifst.draw(H, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="standard_ctc_topo_yesno.pdf") + # See the link below to visualize the above PDF + # https://t.ly/7uXZ9 + + # Now test HL + + # We need to add one to all tokens since we want to use ID 0 + # for epsilon + add_one(H, treat_ilabel_zero_specially=False, update_olabel=True) + + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id, + ) + + fst_dot = kaldifst.draw(H, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="standard_ctc_topo_disambig_yesno.pdf") + + L = make_lexicon_fst_with_silence(lexicon) + + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + + H.output_symbols = None + + kaldifst.arcsort(H, sort_type="olabel") + kaldifst.arcsort(L, sort_type="ilabel") + HL = kaldifst.compose(H, L) + + lexicon.id2token[0] = "" + lexicon.token2id[""] = 0 + + isym = kaldifst.SymbolTable() + isym.add_symbol(symbol="", key=0) + for i in range(0, lexicon.max_disambig_id + 1): + isym.add_symbol(symbol=lexicon.id2token[i], key=i + 1) + + osym = kaldifst.SymbolTable() + for i, word in lexicon.id2word.items(): + osym.add_symbol(symbol=word, key=i) + + HL.input_symbols = isym + HL.output_symbols = osym + + fst_dot = kaldifst.draw(HL, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="HL_yesno.pdf") + + +def test_librispeech(): + lang_dir = ( + "/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/data/lang_bpe_500" + ) + + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! 
Skip testing") + return + + lexicon = Lexicon(lang_dir) + HL = kaldifst.StdVectorFst.read(lang_dir + "/HL.fst") + + sp = spm.SentencePieceProcessor() + sp.load(lang_dir + "/bpe.model") + + i = lexicon.word2id["HELLOA"] + k = lexicon.word2id["WORLD"] + print(i, k) + s = f""" + 0 1 {i} {i} + 1 2 {k} {k} + 2 + """ + fst = kaldifst.compile( + s=s, + acceptor=False, + ) + + L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False) + kaldifst.arcsort(L, sort_type="olabel") + with open("L.fst.txt", "w") as f: + print(L, file=f) + + fst = kaldifst.compose(L, fst) + print(fst) + fst_dot = kaldifst.draw(fst, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="a.pdf") + print(sp.encode(["HELLOA", "WORLD"])) + + +def main(): + test_yesno() + test_librispeech() + + +if __name__ == "__main__": + main() diff --git a/icefall/ctc/test_prepare_lang.py b/icefall/ctc/test_prepare_lang.py new file mode 100755 index 000000000..6c4b9e510 --- /dev/null +++ b/icefall/ctc/test_prepare_lang.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +from pathlib import Path + +import graphviz +import kaldifst +from prepare_lang import Lexicon, make_lexicon_fst_with_silence + + +def test_yesno(): + lang_dir = "/Users/fangjun/open-source/icefall/egs/yesno/ASR/data/lang_phone" + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! Skip testing") + return + + lexicon = Lexicon(lang_dir) + + L = make_lexicon_fst_with_silence(lexicon) + + isym = kaldifst.SymbolTable() + for i, token in lexicon.id2token.items(): + isym.add_symbol(symbol=token, key=i) + + osym = kaldifst.SymbolTable() + for i, word in lexicon.id2word.items(): + osym.add_symbol(symbol=word, key=i) + + L.input_symbols = isym + L.output_symbols = osym + fst_dot = kaldifst.draw(L, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="L_yesno.pdf") + # See the link below to visualize the above PDF + # https://t.ly/jMfXW + + +def main(): + test_yesno() + + +if __name__ == "__main__": + main() diff --git a/icefall/ctc/topo.py b/icefall/ctc/topo.py new file mode 100644 index 000000000..6a96dd038 --- /dev/null +++ b/icefall/ctc/topo.py @@ -0,0 +1,137 @@ +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +import kaldifst + + +# Note the name contains `standard`; it means there will be non-standard +# topologies. +def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst: + """Build a standard CTC topology. + + Args: + Maximum valid token ID. We assume token IDs are contiguous + and starts from 0. In other words, the vocabulary size is + ``max_token_id + 1``. We assume the ID of the blank symbol is 0. + """ + # Token ID starts from 0 and there are as many states as the + # number of tokens. + # + # Note that epsilon is not a token and the token with ID 0 in tokens.txt + # is not an epsilon. It means input label 0 of the resulting FST does + # not represent an epsilon. + # + # You can use the function `add_one()` to modify the input/output labels + # of the resulting FST + + num_states = max_token_id + 1 + + # Step 1: Create as many states as the number of tokens. + # Each state is a final state + fst = kaldifst.StdVectorFst() + for i in range(num_states): + s = fst.add_state() + fst.set_final(state=s, weight=0) + + # Step 2: Set state 0 as the start state. + # We assume the ID of the blank symbol is 0. + fst.start = 0 + + # Step 3: Build a fully connected graph. 
+ for i in range(num_states): + for k in range(num_states): + fst.add_arc( + state=i, + arc=kaldifst.StdArc( + ilabel=k, + olabel=k if i != k else 0, # if i==k, it is a self loop + weight=0, + nextstate=k, + ), + ) + # Please see ./test_ctc_topo.py if you want to know what the resulting + # FST looks like + + return fst + + +def add_one( + fst: kaldifst.StdVectorFst, + treat_ilabel_zero_specially: bool, + update_olabel: bool, +) -> None: + """Modify the input and output labels of the given FST in-place. + + Args: + fst: + The FST to be modified. It is changed in-place. + treat_ilabel_zero_specially: + If True, then every non-zero input label is increased by one and the + zero input label is not changed. + If False, then every input label is increased by one. + update_olabel: + If False, the output label is not changed. + If True, then every non-zero output label is increased by one. + In either case, output label with 0 is not changed. + """ + for state in kaldifst.StateIterator(fst): + for arc in kaldifst.ArcIterator(fst, state): + # If treat_ilabel_zero_specially is False, we always change it + # Otherwise, we only change non-zero input labels + if treat_ilabel_zero_specially is False or arc.ilabel != 0: + arc.ilabel += 1 + + if update_olabel and arc.olabel != 0: + arc.olabel += 1 + + if fst.input_symbols is not None: + input_symbols = kaldifst.SymbolTable() + input_symbols.add_symbol(symbol="", key=0) + + for i in range(0, fst.input_symbols.num_symbols()): + s = fst.input_symbols.find(i) + input_symbols.add_symbol(symbol=s, key=i + 1) + + fst.input_symbols = input_symbols + + if update_olabel and fst.output_symbols is not None: + output_symbols = kaldifst.SymbolTable() + output_symbols.add_symbol(symbol="", key=0) + + for i in range(0, fst.output_symbols.num_symbols()): + s = fst.output_symbols.find(i) + output_symbols.add_symbol(symbol=s, key=i + 1) + + fst.output_symbols = output_symbols + + +def add_disambig_self_loops(fst: kaldifst.StdVectorFst, start: int, end: int): + """Add self-loops to each state. + + For each disambig symbol, we add a self-loop with input label disambig_id + and output label diambig_id of that disambig symbol. + + Args: + fst: + It is changed in-place. + start: + The ID of #0 + end: + The ID of the last disambig symbol. For instance if there are 3 + disambig symbols ``#0``, ``#1``, and ``#2``, then ``end`` is the ID + of ``#2``. 
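+
+    See ./test_ctc_topo.py in this folder for an example of how ``start``
+    and ``end`` are chosen when the labels of ``fst`` have been shifted
+    by ``add_one()``.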
+ """ + for state in kaldifst.StateIterator(fst): + for i in range(start, end + 1): + fst.add_arc( + state=state, + arc=kaldifst.StdArc( + ilabel=i, + olabel=i, + weight=0, + nextstate=state, + ), + ) + + if fst.output_symbols: + for i in range(start, end + 1): + fst.output_symbols.add_symbol(symbol=f"#{i-start}", key=i) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 98870684e..700dc1500 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -23,6 +23,7 @@ from typing import Optional, Tuple, List import torch from torch import Tensor, nn + class TensorDiagnosticOptions(object): """Options object for tensor diagnostics: @@ -77,11 +78,11 @@ def get_tensor_stats( elif stats_type == "abs": x = x.abs() elif stats_type == "rms": - x = x ** 2 + x = x**2 elif stats_type == "positive": x = (x > 0).to(dtype=torch.float) else: - assert stats_type in [ "value", "max", "min" ] + assert stats_type in ["value", "max", "min"] sum_dims = [d for d in range(x.ndim) if d != dim] if len(sum_dims) > 0: @@ -121,10 +122,10 @@ class TensorDiagnostic(object): self.class_name = None # will assign in accumulate() self.stats = None # we'll later assign a list to self.stats. - # It's a list of dicts, indexed by dim (i.e. by the - # axis of the tensor). The dicts, in turn, are - # indexed by `stats-type` which are strings in - # ["abs", "max", "min", "positive", "value", "rms"]. + # It's a list of dicts, indexed by dim (i.e. by the + # axis of the tensor). The dicts, in turn, are + # indexed by `stats-type` which are strings in + # ["abs", "max", "min", "positive", "value", "rms"]. # scalar_stats contains some analysis of the activations and gradients, self.scalar_stats = None @@ -139,7 +140,6 @@ class TensorDiagnostic(object): # only adding a new element to the list if there was a different dim. # if the string in the key is "eigs", if we detect a length mismatch we put None as the value. - def accumulate(self, x, class_name: Optional[str] = None): """ Accumulate tensors. @@ -193,17 +193,12 @@ class TensorDiagnostic(object): done = True break if not done: - if ( - this_dim_stats[stats_type] != [] - and stats_type == "eigs" - ): + if this_dim_stats[stats_type] != [] and stats_type == "eigs": # >1 size encountered on this dim, e.g. it's a batch or time dimension, # don't accumulat "eigs" stats type, it uses too much memory this_dim_stats[stats_type] = None else: - this_dim_stats[stats_type].append( - TensorAndCount(stats, count) - ) + this_dim_stats[stats_type].append(TensorAndCount(stats, count)) def print_diagnostics(self): """Print diagnostics for each dimension of the tensor.""" @@ -220,8 +215,11 @@ class TensorDiagnostic(object): for r, v in zip(rms_stats_list, value_stats_list): stddev_stats_list.append( # r.count and v.count should be the same, but we don't check this. - TensorAndCount(r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20), - r.count)) + TensorAndCount( + r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20), + r.count, + ) + ) this_dim_stats["stddev"] = stddev_stats_list for stats_type, stats_list in this_dim_stats.items(): @@ -232,7 +230,6 @@ class TensorDiagnostic(object): assert stats_type == "eigs" continue - def get_count(count): return 1 if stats_type in ["max", "min"] else count @@ -250,22 +247,20 @@ class TensorDiagnostic(object): eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print( - "Error getting eigenvalues, trying another method." 
- ) + print("Error getting eigenvalues, trying another method.") eigs, _ = torch.eig(stats) stats = eigs.norm(dim=1).sqrt() # sqrt so it reflects data magnitude, like stddev- not variance - if stats_type in [ "rms", "stddev" ]: + if stats_type in ["rms", "stddev"]: # we stored the square; after aggregation we need to take sqrt. stats = stats.sqrt() # if `summarize` we print percentiles of the stats; else, # we print out individual elements. - summarize = ( - len(stats_list) > 1 - ) or self.opts.dim_is_summarized(stats.numel()) + summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized( + stats.numel() + ) if summarize: # usually `summarize` will be true # print out percentiles. stats = stats.sort()[0] @@ -282,15 +277,15 @@ class TensorDiagnostic(object): ans = stats.tolist() ans = ["%.2g" % x for x in ans] ans = "[" + " ".join(ans) + "]" - if stats_type in [ "value", "rms", "stddev", "eigs" ]: + if stats_type in ["value", "rms", "stddev", "eigs"]: # This norm is useful because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, # speaking in an approximate way, how much of that largest eigenvalue # can be attributed to the mean of the distribution. - norm = (stats ** 2).sum().sqrt().item() + norm = (stats**2).sum().sqrt().item() ans += f", norm={norm:.2g}" mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() + rms = (stats**2).mean().sqrt().item() ans += f", mean={mean:.3g}, rms={rms:.3g}" # OK, "ans" contains the actual stats, e.g. @@ -298,11 +293,11 @@ class TensorDiagnostic(object): sizes = [x.tensor.shape[0] for x in stats_list] size_str = ( - f"{sizes[0]}" - if len(sizes) == 1 - else f"{min(sizes)}..{max(sizes)}" + f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" + ) + maybe_class_name = ( + f" type={self.class_name}," if self.class_name is not None else "" ) - maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" print( f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}" ) @@ -330,7 +325,6 @@ class ScalarDiagnostic(object): self.sum_gradsq = None self.sum_abs_grad = None - def accumulate_input(self, x: Tensor, class_name: Optional[str] = None): """ Called in forward pass. @@ -347,8 +341,10 @@ class ScalarDiagnostic(object): limit = 10 if len(self.saved_inputs) > limit: - print(f"ERROR: forward pass called for this module over {limit} times with no backward pass. " - f" Will not accumulate scalar stats.") + print( + f"ERROR: forward pass called for this module over {limit} times with no backward pass. " + f" Will not accumulate scalar stats." 
+ ) self.is_ok = False return self.saved_inputs.append(x) @@ -359,11 +355,15 @@ class ScalarDiagnostic(object): if self.is_forward_pass: self.is_forward_pass = False - last_shape = 'n/a' if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape + last_shape = ( + "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape + ) if len(self.saved_inputs) == 0 or grad.shape != last_shape: - print(f"ERROR: shape mismatch or no forward activation present when backward " - f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}" - f", shape-of-last-saved-input={last_shape}") + print( + f"ERROR: shape mismatch or no forward activation present when backward " + f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}" + f", shape-of-last-saved-input={last_shape}" + ) self.is_ok = False return @@ -384,11 +384,19 @@ class ScalarDiagnostic(object): self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side) # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1] - self.counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long, device=x.device) - self.sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) + self.counts = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.long, device=x.device + ) + self.sum_grad = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.double, device=x.device + ) # sum_gradsq is for getting error bars. - self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) - self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) + self.sum_gradsq = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.double, device=x.device + ) + self.sum_abs_grad = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.double, device=x.device + ) # this will round down. 
x = (x / self.tick_scale).to(torch.long) @@ -397,20 +405,21 @@ class ScalarDiagnostic(object): self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x)) self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double)) - self.sum_gradsq.index_add_(dim=0, index=x, source=(grad*grad).to(torch.double)) + self.sum_gradsq.index_add_( + dim=0, index=x, source=(grad * grad).to(torch.double) + ) self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double)) - def print_diagnostics(self): """Print diagnostics.""" if self.is_ok is False or self.counts is None: print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}") return - counts = self.counts.to('cpu') - sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32) - sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32) - sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32) + counts = self.counts.to("cpu") + sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32) + sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32) + sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32) counts_cumsum = counts.cumsum(dim=0) counts_tot = counts_cumsum[-1] @@ -433,19 +442,22 @@ class ScalarDiagnostic(object): bin_abs_grad = torch.zeros(num_bins) bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad) - avg_grad = (bin_grad / bin_counts) + avg_grad = bin_grad / bin_counts avg_grad_stddev = (bin_gradsq / bin_counts).sqrt() - bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin + bin_boundary_counts = ( + torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin + ) bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts) # boundaries are the "x" values between the bins, e.g. corresponding to the # locations of percentiles of the distribution. num_ticks_per_side = counts.numel() // 2 bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale - bin_grad = bin_grad / (bin_counts + 1) - bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1) # consider this a standard deviation. + bin_conf_interval = bin_gradsq.sqrt() / ( + bin_counts + 1 + ) # consider this a standard deviation. # bin_grad / bin_abs_grad will give us a sense for how important in a practical sense, # the gradients are. bin_abs_grad = bin_abs_grad / (bin_counts + 1) @@ -458,8 +470,9 @@ class ScalarDiagnostic(object): x = "[" + " ".join(x) + "]" return x - - maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" + maybe_class_name = ( + f" type={self.class_name}," if self.class_name is not None else "" + ) print( f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, " @@ -467,7 +480,6 @@ class ScalarDiagnostic(object): ) - class ModelDiagnostic(object): """This class stores diagnostics for all tensors in the torch.nn.Module. @@ -485,9 +497,8 @@ class ModelDiagnostic(object): self.opts = opts self.diagnostics = dict() - def __getitem__(self, name: str): - T = ScalarDiagnostic if name[-7:] == '.scalar' else TensorDiagnostic + T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic if name not in self.diagnostics: self.diagnostics[name] = T(self.opts, name) return self.diagnostics[name] @@ -502,18 +513,19 @@ def get_class_name(module: nn.Module): ans = type(module).__name__ # we put the below in try blocks in case anyone is using a different version of these modules that # might have different member names. 
- if ans == 'Balancer' or ans == 'ActivationBalancer': + if ans == "Balancer" or ans == "ActivationBalancer": try: - ans += f'[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]' + ans += f"[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]" except: pass - elif ans == 'AbsValuePenalizer': + elif ans == "AbsValuePenalizer": try: - ans += f'[{module.limit}]' + ans += f"[{module.limit}]" except: pass return ans + def attach_diagnostics( model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None ) -> ModelDiagnostic: @@ -538,73 +550,85 @@ def attach_diagnostics( if name == "": name = "" - - # Setting model_diagnostic=ans and n=name below, instead of trying to # capture the variables, ensures that we use the current values. # (this matters for `name`, since the variable gets overwritten). # These closures don't really capture by value, only by # "the final value the variable got in the function" :-( - def forward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): + def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] - if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.output"].accumulate(_output, - class_name=get_class_name(_module)) + if isinstance(_output, Tensor) and _output.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.output"].accumulate( + _output, class_name=get_class_name(_module) + ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, - class_name=get_class_name(_module)) + if o.dtype in (torch.float32, torch.float16, torch.float64): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) - def backward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] - if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.grad"].accumulate(_output, - class_name=get_class_name(_module)) + if isinstance(_output, Tensor) and _output.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.grad"].accumulate( + _output, class_name=get_class_name(_module) + ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, - class_name=get_class_name(_module)) - + if o.dtype in (torch.float32, torch.float16, torch.float64): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) - if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish", "Swoosh"]: + if type(module).__name__ in [ + "Sigmoid", + "Tanh", + "ReLU", + "TanSwish", + "Swish", + "DoubleSwish", + "Swoosh", + ]: # For these specific module types, accumulate some additional diagnostics # that can help us improve the activation function. 
These require a lot of memory, # to save the forward activations, so limit this to some select classes. # Note: this will not work correctly for all model types. def scalar_forward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name + _module, _input, _output, _model_diagnostic=ans, _name=name ): if isinstance(_input, tuple): - _input, = _input + (_input,) = _input assert isinstance(_input, Tensor) - _model_diagnostic[f"{_name}.scalar"].accumulate_input(_input, - class_name=get_class_name(_module)) + _model_diagnostic[f"{_name}.scalar"].accumulate_input( + _input, class_name=get_class_name(_module) + ) def scalar_backward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name + _module, _input, _output, _model_diagnostic=ans, _name=name ): if isinstance(_output, tuple): - _output, = _output + (_output,) = _output assert isinstance(_output, Tensor) _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output) module.register_forward_hook(scalar_forward_hook) module.register_backward_hook(scalar_backward_hook) - - for name, parameter in model.named_parameters(): def param_backward_hook( diff --git a/icefall/otc_graph_compiler.py b/icefall/otc_graph_compiler.py new file mode 100644 index 000000000..bfd679452 --- /dev/null +++ b/icefall/otc_graph_compiler.py @@ -0,0 +1,246 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Johns Hopkins University (author: Dongji Gao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pathlib import Path +from typing import List, Union + +import k2 +import sentencepiece as spm +import torch + +from icefall.utils import str2bool + + +class OtcTrainingGraphCompiler(object): + def __init__( + self, + lang_dir: Path, + otc_token: str, + device: Union[str, torch.device] = "cpu", + sos_token: str = "", + eos_token: str = "", + initial_bypass_weight: float = 0.0, + initial_self_loop_weight: float = 0.0, + bypass_weight_decay: float = 0.0, + self_loop_weight_decay: float = 0.0, + ) -> None: + """ + Args: + lang_dir: + This directory is expected to contain the following files: + + - bpe.model + - words.txt + otc_token: + The special token in OTC that represent all non-blank tokens + device: + It indicates CPU or CUDA. + sos_token: + The word piece that represents sos. + eos_token: + The word piece that represents eos. 
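As a usage sketch for the class defined above: construction might look roughly like the snippet below. All concrete values are placeholders chosen for illustration, not defaults prescribed by this file; in particular the lang directory, the OTC token string, the sos/eos pieces and the weight values are assumptions.

from pathlib import Path

import torch

from icefall.otc_graph_compiler import OtcTrainingGraphCompiler

graph_compiler = OtcTrainingGraphCompiler(
    lang_dir=Path("data/lang_bpe_200"),  # hypothetical dir with bpe.model, tokens.txt, words.txt
    otc_token="<star>",  # assumed name of the OTC wildcard token
    device=torch.device("cuda", 0),
    sos_token="<sos/eos>",  # sos/eos word pieces; must exist in bpe.model
    eos_token="<sos/eos>",
    initial_bypass_weight=-19.0,  # illustrative values only
    initial_self_loop_weight=-19.0,
    bypass_weight_decay=0.975,
    self_loop_weight_decay=0.999,
)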
+ """ + lang_dir = Path(lang_dir) + bpe_model_file = lang_dir / "bpe.model" + sp = spm.SentencePieceProcessor() + sp.load(str(bpe_model_file)) + self.sp = sp + self.token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + + self.otc_token = otc_token + assert self.otc_token in self.token_table + + self.device = device + + self.sos_id = self.sp.piece_to_id(sos_token) + self.eos_id = self.sp.piece_to_id(eos_token) + + assert self.sos_id != self.sp.unk_id() + assert self.eos_id != self.sp.unk_id() + + max_token_id = self.get_max_token_id() + ctc_topo = k2.ctc_topo(max_token_id, modified=False) + self.ctc_topo = ctc_topo.to(self.device) + + self.initial_bypass_weight = initial_bypass_weight + self.initial_self_loop_weight = initial_self_loop_weight + self.bypass_weight_decay = bypass_weight_decay + self.self_loop_weight_decay = self_loop_weight_decay + + def get_max_token_id(self): + max_token_id = 0 + for symbol in self.token_table.symbols: + if not symbol.startswith("#"): + max_token_id = max(self.token_table[symbol], max_token_id) + assert max_token_id > 0 + + return max_token_id + + def make_arc( + self, + from_state: int, + to_state: int, + symbol: Union[str, int], + weight: float, + ): + return f"{from_state} {to_state} {symbol} {weight}" + + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of piece IDs. + + Args: + texts: + It is a list of strings. Each string consists of space(s) + separated words. An example containing two strings is given below: + + ['HELLO ICEFALL', 'HELLO k2'] + Returns: + Return a list-of-list of piece IDs. + """ + return self.sp.encode(texts, out_type=int) + + def compile( + self, + texts: List[str], + allow_bypass_arc: str2bool = True, + allow_self_loop_arc: str2bool = True, + bypass_weight: float = 0.0, + self_loop_weight: float = 0.0, + ) -> k2.Fsa: + """Build a OTC graph from a texts (list of words). + + Args: + texts: + A list of strings. Each string contains a sentence for an utterance. + A sentence consists of spaces separated words. An example `texts` + looks like: + ['hello icefall', 'CTC training with k2'] + allow_bypass_arc: + Whether to add bypass arc to training graph for substitution + and insertion errors (wrong or extra words in the transcript). + allow_self_loop_arc: + Whether to add self-loop arc to training graph for deletion + errors (missing words in the transcript). + bypass_weight: + Weight associated with bypass arc. + self_loop_weight: + Weight associated with self-loop arc. + + Return: + Return an FsaVec, which is the result of composing a + CTC topology with OTC FSAs constructed from the given texts. 
+ """ + + transcript_fsa = self.convert_transcript_to_fsa( + texts, + self.otc_token, + allow_bypass_arc, + allow_self_loop_arc, + bypass_weight, + self_loop_weight, + ) + transcript_fsa = transcript_fsa.to(self.device) + fsa_with_self_loop = k2.remove_epsilon_and_add_self_loops(transcript_fsa) + fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop) + + graph = k2.compose( + self.ctc_topo, + fsa_with_self_loop, + treat_epsilons_specially=False, + ) + assert graph.requires_grad is False + + return graph + + def convert_transcript_to_fsa( + self, + texts: List[str], + otc_token: str, + allow_bypass_arc: str2bool = True, + allow_self_loop_arc: str2bool = True, + bypass_weight: float = 0.0, + self_loop_weight: float = 0.0, + ): + otc_token_id = self.token_table[otc_token] + + transcript_fsa_list = [] + for text in texts: + text_piece_ids = [] + + for word in text.split(): + piece_ids = self.sp.encode(word, out_type=int) + text_piece_ids.append(piece_ids) + + arcs = [] + start_state = 0 + cur_state = start_state + next_state = 1 + + for piece_ids in text_piece_ids: + bypass_cur_state = cur_state + + if allow_self_loop_arc: + self_loop_arc = self.make_arc( + cur_state, + cur_state, + otc_token_id, + self_loop_weight, + ) + arcs.append(self_loop_arc) + + for piece_id in piece_ids: + arc = self.make_arc(cur_state, next_state, piece_id, 0.0) + arcs.append(arc) + + cur_state = next_state + next_state += 1 + + bypass_next_state = cur_state + if allow_bypass_arc: + bypass_arc = self.make_arc( + bypass_cur_state, + bypass_next_state, + otc_token_id, + bypass_weight, + ) + arcs.append(bypass_arc) + bypass_cur_state = cur_state + + if allow_self_loop_arc: + self_loop_arc = self.make_arc( + cur_state, + cur_state, + otc_token_id, + self_loop_weight, + ) + arcs.append(self_loop_arc) + + # Deal with final state + final_state = next_state + final_arc = self.make_arc(cur_state, final_state, -1, 0.0) + arcs.append(final_arc) + arcs.append(f"{final_state}") + sorted_arcs = sorted(arcs, key=lambda a: int(a.split()[0])) + + transcript_fsa = k2.Fsa.from_str("\n".join(sorted_arcs)) + transcript_fsa = k2.arc_sort(transcript_fsa) + transcript_fsa_list.append(transcript_fsa) + + transcript_fsa_vec = k2.create_fsa_vec(transcript_fsa_list) + + return transcript_fsa_vec diff --git a/icefall/profiler.py b/icefall/profiler.py index dc76ebebc..49e138579 100644 --- a/icefall/profiler.py +++ b/icefall/profiler.py @@ -70,25 +70,17 @@ class FlopsProfiler(object): module_flop_count.append([]) if not hasattr(module, "__pre_hook_handle__"): - module.__pre_hook_handle__ = module.register_forward_pre_hook( - pre_hook - ) + module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook) def post_hook(module, input, output): if module_flop_count: - module.__flops__ += sum( - [elem[1] for elem in module_flop_count[-1]] - ) + module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]]) module_flop_count.pop() if not hasattr(module, "__post_hook_handle__"): - module.__post_hook_handle__ = module.register_forward_hook( - post_hook - ) + module.__post_hook_handle__ = module.register_forward_hook(post_hook) - self.model.apply( - partial(register_module_hooks, ignore_list=ignore_list) - ) + self.model.apply(partial(register_module_hooks, ignore_list=ignore_list)) self.started = True self.func_patched = True @@ -194,9 +186,7 @@ def _prelu_flops_compute(input: Tensor, weight: Tensor): return input.numel() -def _elu_flops_compute( - input: Tensor, alpha: float = 1.0, inplace: bool = False -): +def _elu_flops_compute(input: 
Tensor, alpha: float = 1.0, inplace: bool = False): return input.numel() @@ -259,9 +249,7 @@ def _conv_flops_compute( output_dims.append(output_dim) filters_per_channel = out_channels // groups - conv_per_position_macs = ( - int(_prod(kernel_dims)) * in_channels * filters_per_channel - ) + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel active_elements_count = batch_size * int(_prod(output_dims)) overall_conv_macs = conv_per_position_macs * active_elements_count overall_conv_flops = 2 * overall_conv_macs @@ -297,7 +285,6 @@ def _conv_trans_flops_compute( output_dims = [] for idx, input_dim in enumerate(input_dims): - output_dim = ( input_dim + 2 * paddings[idx] @@ -310,9 +297,7 @@ def _conv_trans_flops_compute( dilations = dilation if type(dilation) is tuple else (dilation, dilation) filters_per_channel = out_channels // groups - conv_per_position_macs = ( - int(_prod(kernel_dims)) * in_channels * filters_per_channel - ) + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel active_elements_count = batch_size * int(_prod(input_dims)) overall_conv_macs = conv_per_position_macs * active_elements_count overall_conv_flops = 2 * overall_conv_macs @@ -389,9 +374,7 @@ def _upsample_flops_compute(input, **kwargs): else: return int(size), 0 scale_factor = kwargs.get("scale_factor", None) - assert ( - scale_factor is not None - ), "either size or scale_factor should be defined" + assert scale_factor is not None, "either size or scale_factor should be defined" flops = input.numel() if isinstance(scale_factor, tuple) and len(scale_factor) == len(input): flops * int(_prod(scale_factor)) @@ -593,12 +576,8 @@ def _patch_functionals(): F.embedding = wrapFunc(F.embedding, _embedding_flops_compute) # swoosh functions in k2 - k2.swoosh_l_forward = wrapFunc( - k2.swoosh_l_forward, _k2_swoosh_flops_compute - ) - k2.swoosh_r_forward = wrapFunc( - k2.swoosh_r_forward, _k2_swoosh_flops_compute - ) + k2.swoosh_l_forward = wrapFunc(k2.swoosh_l_forward, _k2_swoosh_flops_compute) + k2.swoosh_r_forward = wrapFunc(k2.swoosh_r_forward, _k2_swoosh_flops_compute) k2.swoosh_l = wrapFunc(k2.swoosh_l, _k2_swoosh_flops_compute) k2.swoosh_r = wrapFunc(k2.swoosh_r, _k2_swoosh_flops_compute) @@ -612,9 +591,7 @@ def _patch_tensor_methods(): torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute) torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute) - torch.Tensor.addmm = wrapFunc( - torch.Tensor.addmm, _tensor_addmm_flops_compute - ) + torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute) torch.mul = wrapFunc(torch.mul, _mul_flops_compute) torch.Tensor.mul = wrapFunc(torch.Tensor.mul, _mul_flops_compute) @@ -631,14 +608,10 @@ def _patch_tensor_methods(): torch.tanh = wrapFunc(torch.tanh, _tanh_flops_compute) - torch.Tensor.softmax = wrapFunc( - torch.Tensor.softmax, _softmax_flops_compute - ) + torch.Tensor.softmax = wrapFunc(torch.Tensor.softmax, _softmax_flops_compute) torch.sigmoid = wrapFunc(torch.sigmoid, _sigmoid_flops_compute) - torch.Tensor.sigmoid = wrapFunc( - torch.Tensor.sigmoid, _sigmoid_flops_compute - ) + torch.Tensor.sigmoid = wrapFunc(torch.Tensor.sigmoid, _sigmoid_flops_compute) def _reload_functionals(): @@ -732,15 +705,11 @@ def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): flops += rnn_module.hidden_size * 4 # two hadamard _product and add for C state flops += ( - rnn_module.hidden_size - + rnn_module.hidden_size - + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + 
rnn_module.hidden_size ) # final hadamard flops += ( - rnn_module.hidden_size - + rnn_module.hidden_size - + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size ) return flops diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py index d51a4b76b..28b908f82 100755 --- a/icefall/rnn_lm/check-onnx-streaming.py +++ b/icefall/rnn_lm/check-onnx-streaming.py @@ -112,7 +112,6 @@ def main(): for torch_v, onnx_v in zip( (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0) ): - assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( torch_v.shape, onnx_v.shape, diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 3d206d139..0178b80bf 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -463,7 +463,6 @@ def train_one_epoch( cur_batch_idx = params.get("cur_batch_idx", 0) for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: continue cur_batch_idx = batch_idx diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py index 7150297d6..231aca7f1 100755 --- a/icefall/shared/make_kn_lm.py +++ b/icefall/shared/make_kn_lm.py @@ -225,7 +225,6 @@ class NgramCounts: for n in range(0, self.ngram_order - 1): this_order_counts = self.counts[n] for hist, counts_for_hist in this_order_counts.items(): - n_star_star = 0 for w in counts_for_hist.word_to_count.keys(): n_star_star += len(counts_for_hist.word_to_context[w]) @@ -424,7 +423,6 @@ class NgramCounts: if __name__ == "__main__": - ngram_counts = NgramCounts(args.ngram_order) if args.text is None: diff --git a/icefall/transformer_lm/model.py b/icefall/transformer_lm/model.py index 79dda3168..c78cf1821 100644 --- a/icefall/transformer_lm/model.py +++ b/icefall/transformer_lm/model.py @@ -103,7 +103,6 @@ class TransformerLM(torch.nn.Module): return nll_loss def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): - bs = x.size(0) state = None diff --git a/icefall/utils.py b/icefall/utils.py index 26e6936bb..410340d9d 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -263,6 +263,70 @@ def get_texts( return aux_labels.tolist() +def encode_supervisions_otc( + supervisions: dict, + subsampling_factor: int, + token_ids: Optional[List[List[int]]] = None, +) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]]]: + """ + Encodes Lhotse's ``batch["supervisions"]`` dict into + a pair of torch Tensor, and a list of transcription strings or token indexes + + The supervision tensor has shape ``(batch_size, 3)``. + Its second dimension contains information about sequence index [0], + start frames [1] and num frames [2]. + + The batch items might become re-ordered during this operation -- the + returned tensor and list of strings are guaranteed to be consistent with + each other. 
+ """ + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + + ids = [] + verbatim_texts = [] + sorted_ids = [] + sorted_verbatim_texts = [] + + for cut in supervisions["cut"]: + id = cut.id + if hasattr(cut.supervisions[0], "verbatim_text"): + verbatim_text = cut.supervisions[0].verbatim_text + else: + verbatim_text = "" + ids.append(id) + verbatim_texts.append(verbatim_text) + + for index in indices.tolist(): + sorted_ids.append(ids[index]) + sorted_verbatim_texts.append(verbatim_texts[index]) + + if token_ids is None: + texts = supervisions["text"] + res = [texts[idx] for idx in indices] + else: + res = [token_ids[idx] for idx in indices] + + return supervision_segments, res, sorted_ids, sorted_verbatim_texts + + @dataclass class DecodingResults: # timestamps[i][k] contains the frame number on which tokens[i][k] diff --git a/requirements-ci.txt b/requirements-ci.txt index 2433e190b..1eba69764 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -20,9 +20,11 @@ kaldialign==0.7.1 sentencepiece==0.1.96 tensorboard==2.8.0 typeguard==2.13.3 +black==22.3.0 multi_quantization onnx onnxmltools onnxruntime kaldifst +kaldi-decoder diff --git a/requirements.txt b/requirements.txt index a07f6b7c7..5a8326619 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,9 @@ kaldifst kaldilm kaldialign +kaldi-decoder sentencepiece>=0.1.96 tensorboard typeguard dill +black==22.3.0