Merge branch 'master' into dev/cv-zipformer

This commit is contained in:
jinzr 2024-03-15 11:03:32 +08:00
commit 3560e2260e
119 changed files with 5673 additions and 212 deletions

1
.github/scripts/.gitignore vendored Normal file

@ -0,0 +1 @@
piper_phonemize.html

29
.github/scripts/generate-piper-phonemize-page.py vendored Executable file

@ -0,0 +1,29 @@
#!/usr/bin/env python3
def main():
prefix = (
"https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
)
files = [
"piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
]
with open("piper_phonemize.html", "w") as f:
for file in files:
url = prefix + file
f.write(f'<a href="{url}">{file}</a><br/>\n')
if __name__ == "__main__":
main()


@ -15,9 +15,9 @@ function prepare_data() {
# cause OOM error for CI later.
mkdir -p download/lm
pushd download/lm
wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt
ls -lh
gunzip librispeech-lm-norm.txt.gz

157
.github/scripts/ljspeech/TTS/run.sh vendored Executable file

@ -0,0 +1,157 @@
#!/usr/bin/env bash
set -ex
python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
python3 -m pip install espnet_tts_frontend
python3 -m pip install numba
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/ljspeech/TTS
sed -i.bak s/600/8/g ./prepare.sh
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
sed -i.bak s/500/5/g ./prepare.sh
git diff
function prepare_data() {
# We have created a subset of the data for testing
#
mkdir download
pushd download
wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
tar xvf LJSpeech-1.1.tar.bz2
popd
./prepare.sh
tree .
}
function train() {
pushd ./vits
sed -i.bak s/200/3/g ./train.py
git diff .
popd
for t in low medium high; do
./vits/train.py \
--exp-dir vits/exp-$t \
--model-type $t \
--num-epochs 1 \
--save-every-n 1 \
--num-buckets 2 \
--tokens data/tokens.txt \
--max-duration 20
ls -lh vits/exp-$t
done
}
function infer() {
for t in low medium high; do
./vits/infer.py \
--num-buckets 2 \
--model-type $t \
--epoch 1 \
--exp-dir ./vits/exp-$t \
--tokens data/tokens.txt \
--max-duration 20
done
}
function export_onnx() {
for t in low medium high; do
./vits/export-onnx.py \
--model-type $t \
--epoch 1 \
--exp-dir ./vits/exp-$t \
--tokens data/tokens.txt
ls -lh vits/exp-$t/
done
}
function test_medium() {
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12
./vits/export-onnx.py \
--model-type medium \
--epoch 820 \
--exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt
ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp
./vits/test_onnx.py \
--model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
--output-filename /icefall/test-medium.wav
ls -lh /icefall/test-medium.wav
d=/icefall/vits-icefall-en_US-ljspeech-medium
mkdir $d
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx
rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12
pushd $d
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
rm espeak-ng-data.tar.bz2
cd ..
tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
rm -rf vits-icefall-en_US-ljspeech-medium
ls -lh *.tar.bz2
popd
}
function test_low() {
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12
./vits/export-onnx.py \
--model-type low \
--epoch 1600 \
--exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt
ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp
./vits/test_onnx.py \
--model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
--output-filename /icefall/test-low.wav
ls -lh /icefall/test-low.wav
d=/icefall/vits-icefall-en_US-ljspeech-low
mkdir $d
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx
rm -rf icefall-tts-ljspeech-vits-low-2024-03-12
pushd $d
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
rm espeak-ng-data.tar.bz2
cd ..
tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
rm -rf vits-icefall-en_US-ljspeech-low
ls -lh *.tar.bz2
popd
}
prepare_data
train
infer
export_onnx
rm -rf vits/exp-{low,medium,high}
test_medium
test_low


@ -56,11 +56,14 @@ jobs:
- name: Build doc
shell: bash
run: |
.github/scripts/generate-piper-phonemize-page.py
cd docs
python3 -m pip install -r ./requirements.txt
make html
touch build/html/.nojekyll
cp -v ../piper_phonemize.html ./build/html/
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
with:

102
.github/workflows/ljspeech.yml vendored Normal file

@ -0,0 +1,102 @@
name: ljspeech
on:
push:
branches:
- master
pull_request:
branches:
- master
workflow_dispatch:
concurrency:
group: ljspeech-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
echo "::set-output name=matrix::${MATRIX}"
ljspeech:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Free space
shell: bash
run: |
ls -lh
df -h
rm -rf /opt/hostedtoolcache
df -h
echo "pwd: $PWD"
echo "github.workspace ${{ github.workspace }}"
- name: Run tests
uses: addnab/docker-run-action@v3
with:
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
run: |
export PYTHONPATH=/icefall:$PYTHONPATH
cd /icefall
git config --global --add safe.directory /icefall
.github/scripts/ljspeech/TTS/run.sh
- name: Display files
shell: bash
run: |
ls -lh
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
with:
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
path: ./*.wav
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
with:
name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
path: ./*.wav
- name: Release exported onnx models
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
overwrite: true
file: vits-icefall-*.tar.bz2
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: tts-models


@ -0,0 +1,225 @@
Finetune from a pre-trained Zipformer model with adapters
=========================================================
This tutorial shows you how to fine-tune a pre-trained **Zipformer**
transducer model on a new dataset with adapters.
Adapters are compact and efficient modules that can be integrated into a pre-trained model
to improve its performance on a new domain. Adapters are injected
between modules of the well-trained neural network; during training, only the parameters
of the adapters are updated. This achieves competitive performance
while requiring much less GPU memory than full fine-tuning. For more details about adapters,
please refer to the original `paper <https://arxiv.org/pdf/1902.00751.pdf#/>`_.
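
The following sketch illustrates the general idea of a bottleneck adapter with a
residual connection. It is a simplified illustration, not the exact module used in
icefall; the names ``Adapter`` and ``d_model`` and the choice of activation are
assumptions:

.. code-block:: python

   import torch


   class Adapter(torch.nn.Module):
       """A bottleneck adapter: down-project, non-linearity, up-project."""

       def __init__(self, d_model: int, adapter_dim: int = 8):
           super().__init__()
           self.down = torch.nn.Linear(d_model, adapter_dim)
           self.up = torch.nn.Linear(adapter_dim, d_model)

       def forward(self, x: torch.Tensor) -> torch.Tensor:
           # Residual connection: the adapter learns a small correction on top
           # of the frozen layer output, so disabling it recovers the original
           # model behaviour.
           return x + self.up(torch.nn.functional.relu(self.down(x)))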
.. HINT::
We assume you have read the page :ref:`install icefall` and have set up
the environment for ``icefall``.
.. HINT::
We recommend using a GPU or several GPUs to run this recipe.
For illustration purposes, we fine-tune the Zipformer transducer model
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You can use your
own data for fine-tuning if you create a manifest for your new dataset.
Data preparation
----------------
Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
to prepare the fine-tuning data used in this tutorial. Only the small subset of GigaSpeech is required.
Model preparation
-----------------
We use the Zipformer model trained on the full LibriSpeech dataset (960 hours) as the initialization. The
checkpoint of the model can be downloaded via the following command:
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
$ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ cd ../data/lang_bpe_500
$ git lfs pull --include bpe.model
$ cd ../../..
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
decoding on the GigaSpeech test sets:
.. code-block:: bash
./zipformer/decode_gigaspeech.py \
--epoch 99 \
--avg 1 \
--exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
--use-averaged-model 0 \
--max-duration 1000 \
--decoding-method greedy_search
You should see the following numbers:
.. code-block::
For dev, WER of different settings are:
greedy_search 20.06 best for dev
For test, WER of different settings are:
greedy_search 19.27 best for test
Fine-tune with adapter
----------------------
We insert 4 adapters with residual connections in each ``Zipformer2EncoderLayer``.
The original model parameters remain untouched during training and only the parameters of
the adapters are updated. The following command starts a fine-tuning experiment with adapters:
.. code-block:: bash
$ do_finetune=1
$ use_adapters=1
$ adapter_dim=8
$ ./zipformer_adapter/train.py \
--world-size 2 \
--num-epochs 20 \
--start-epoch 1 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--use-fp16 1 \
--base-lr 0.045 \
--use-adapters $use_adapters --adapter-dim $adapter_dim \
--bpe-model data/lang_bpe_500/bpe.model \
--do-finetune $do_finetune \
--master-port 13022 \
--finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
--max-duration 1000
The following arguments are related to fine-tuning:
- ``--do-finetune``
If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
**Note that if you want to resume your fine-tuning experiment from a certain epoch, you
need to set this to False.**
- ``--use-adapters``
Whether adapters are used during fine-tuning.
- ``--adapter-dim``
The bottleneck dimension of the adapter module. Typically a small number.
Note that the training log shows the total number of trainable parameters:
.. code-block::
2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)
The trainable parameters make up only 1.15% of the entire model, so training is much faster
and requires less memory than full fine-tuning.
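
You can check this fraction yourself with a generic sketch like the one below;
the assumption that adapter parameter names contain the substring ``"adapter"``
is ours, not the recipe's:

.. code-block:: python

   import torch


   def freeze_except_adapters(model: torch.nn.Module) -> None:
       # Only adapter parameters remain trainable.
       for name, p in model.named_parameters():
           p.requires_grad = "adapter" in name


   def report_trainable(model: torch.nn.Module) -> None:
       num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
       num_total = sum(p.numel() for p in model.parameters())
       print(
           f"A total of {num_trainable} trainable parameters "
           f"({100 * num_trainable / num_total:.3f}% of the whole model)"
       )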
Decoding
--------
After training, let's test the WERs. To test the WERs on the GigaSpeech test sets,
you can execute the following command:
.. code-block:: bash
$ epoch=20
$ avg=10
$ use_adapters=1
$ adapter_dim=8
$ ./zipformer_adapter/decode.py \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--max-duration 600 \
--use-adapters $use_adapters \
--adapter-dim $adapter_dim \
--decoding-method greedy_search
You should see the following numbers:
.. code-block::
For dev, WER of different settings are:
greedy_search 15.44 best for dev
For test, WER of different settings are:
greedy_search 15.42 best for test
The WER on the test set improves from 19.27 to 15.42, demonstrating the effectiveness of adapters.
The same model can be used to perform decoding on the LibriSpeech test sets. You can deactivate the adapters
to recover the performance of the original model:
.. code-block:: bash
$ epoch=20
$ avg=1
$ use_adapters=0
$ adapter_dim=8
$ ./zipformer_adapter/decode.py \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--max-duration 600 \
--use-adapters $use_adapters \
--adapter-dim $adapter_dim \
--decoding-method greedy_search
.. code-block::
For dev, WER of different settings are:
greedy_search 2.23 best for test-clean
For test, WER of different settings are:
greedy_search 4.96 best for test-other
The numbers are the same as reported in `icefall <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md#normal-scaled-model-number-of-model-parameters-65549011-ie-6555-m>`_. Adapter-based
fine-tuning is thus very flexible, as the same model can be used for decoding on both the original and the target domain.
Export the model
----------------
After training, the model can easily be exported to ONNX format using the following command:
.. code-block:: bash
$ use_adapters=1
$ adapter_dim=16
$ ./zipformer_adapter/export-onnx.py \
--tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
--use-averaged-model 1 \
--epoch 20 \
--avg 10 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--use-adapters $use_adapters \
--adapter-dim $adapter_dim \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"


@ -13,3 +13,4 @@ data to improve the performance on new domains.
:caption: Table of Contents
from_supervised/finetune_zipformer
adapter/finetune_adapter


@ -13,6 +13,14 @@ with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
Install extra dependencies
--------------------------
.. code-block:: bash
pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
pip install numba espnet_tts_frontend
Data preparation
----------------
@ -56,7 +64,8 @@ Training
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--tokens data/tokens.txt
--tokens data/tokens.txt \
--model-type high \
--max-duration 500
.. note::
@ -64,6 +73,11 @@ Training
You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``.
.. warning::
If you want a model that runs faster on CPU, please use ``--model-type low``
or ``--model-type medium``.
.. note::
The training can take a long time (usually a couple of days).
@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
Export models
-------------
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
``vits-epoch-*.onnx``.
.. code-block:: bash
@ -120,4 +134,68 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
- ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_
Usage in sherpa-onnx
--------------------
The following describes how to test the exported ONNX model in `sherpa-onnx`_.
.. hint::
`sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.
We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
Please refer to `<https://k2-fsa.github.io/sherpa/onnx/>`_
for more documentation.
Install sherpa-onnx
^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
pip install sherpa-onnx
To check that you have installed `sherpa-onnx`_ successfully, please run:
.. code-block:: bash
which sherpa-onnx-offline-tts
sherpa-onnx-offline-tts --help
Download lexicon files
^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
cd /tmp
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
Run sherpa-onnx
^^^^^^^^^^^^^^^
.. code-block:: bash
cd egs/ljspeech/TTS
sherpa-onnx-offline-tts \
--vits-model=vits/exp/vits-epoch-1000.onnx \
--vits-tokens=data/tokens.txt \
--vits-data-dir=/tmp/espeak-ng-data \
--num-threads=1 \
--output-filename=./high.wav \
"Ask not what your country can do for you; ask what you can do for your country."
.. hint::
You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
as it is being generated.
You should get a file ``high.wav`` after running the above command.
Congratulations! You have successfully trained and exported a text-to-speech
model and run it with `sherpa-onnx`_.


@ -19,7 +19,9 @@ The following table lists the differences among them.
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 |
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).


@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 11: Train RNN LM model"
log "Stage 12: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \
--world-size 1 \


@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
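
The new `icefall/err.py` module itself does not appear in this section of the diff; judging from the call sites and the removed lines, the helper presumably looks like this sketch:

def raise_grad_scale_is_too_small_error(cur_grad_scale: float) -> None:
    # Same message as the inlined RuntimeError blocks it replaces.
    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")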


@ -85,6 +85,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
@ -878,9 +879,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0


@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -871,9 +872,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -882,9 +883,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -86,6 +86,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
@ -985,9 +986,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -83,6 +83,7 @@ from icefall.checkpoint import (
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -570,9 +571,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -15,7 +15,7 @@ perturb_speed=true
#
# - $dl_dir/alimeeting
# This directory contains the following files downloaded from
# https://openslr.org/62/
# https://openslr.org/119/
#
# - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz


@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting
#
# - $dl_dir/alimeeting
# This directory contains the following files downloaded from
# https://openslr.org/62/
# https://openslr.org/119/
#
# - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz


@ -70,6 +70,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -851,9 +852,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -69,6 +69,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -842,9 +843,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -75,6 +75,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1138,9 +1139,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -75,6 +75,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1129,9 +1130,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -894,9 +895,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -890,9 +890,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise RuntimeError(f", exiting: {cur_grad_scale}")
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -82,6 +82,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -966,9 +967,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -79,6 +79,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -911,9 +912,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -909,9 +910,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -908,9 +909,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -1031,9 +1032,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -100,6 +100,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -371,9 +372,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -1034,9 +1035,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -85,6 +85,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1169,9 +1170,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1056,9 +1057,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -232,7 +232,7 @@ class LibriHeavyAsrDataModule:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")


@ -93,6 +93,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -1036,9 +1037,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -103,6 +103,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -1051,9 +1052,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -117,6 +117,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -855,9 +856,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
context_dim=4 * 768
if params.context_injection
else -1, # the output dim of text encoder
context_dim=(
4 * 768 if params.context_injection else -1
), # the output dim of text encoder
context_injection=params.context_injection,
)
return joiner
@ -1398,9 +1399,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -976,9 +977,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -878,9 +879,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -902,9 +903,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -77,6 +77,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -891,9 +892,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -880,9 +881,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -879,9 +880,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -84,6 +84,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -946,9 +947,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -946,9 +947,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -66,6 +66,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import UniqLexicon
from icefall.utils import (
@ -883,9 +884,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -92,6 +92,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -1122,9 +1123,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -90,6 +90,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -1021,9 +1022,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@ -1125,9 +1126,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())


@ -62,6 +62,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -797,9 +798,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -79,6 +79,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon, UniqLexicon
from icefall.mmi import LFMMILoss
@ -816,9 +817,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]


@ -36,3 +36,68 @@ To inference, use:
--epoch 1000 \
--tokens data/tokens.txt
```
## Quality vs speed
If you feel that the trained model is slow at runtime, you can specify the
argument `--model-type` during training. Possible values are:
- `low`, means **low** quality. The resulting model is very small in file size
and runs very fast. The following is a wave file generated by a `low` quality model
https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633
The text is `Ask not what your country can do for you; ask what you can do for your country.`
The exported onnx model has a file size of ``26.8 MB`` (float32).
- `medium`, means **medium** quality.
The following is a wave file generated by a `medium` quality model
https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac
The text is `Ask not what your country can do for you; ask what you can do for your country.`
The exported onnx model has a file size of ``70.9 MB`` (float32).
- `high`, means **high** quality. This is the default value.
The following is a wave file generated by a `high` quality model
https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc
The text is `Ask not what your country can do for you; ask what you can do for your country.`
The exported onnx model has a file size of ``113 MB`` (float32).
A pre-trained `low` model trained on 4 x V100 32GB GPUs with the following command can be found at
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
./vits/train.py \
--world-size 4 \
--num-epochs 1601 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--model-type low \
--max-duration 800
```
A pre-trained `medium` model trained on 4 x V100 32GB GPUs with the following command can be found at
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>
```bash
export CUDA_VISIBLE_DEVICES=4,5,6,7
./vits/train.py \
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp-medium \
--model-type medium \
--max-duration 500
# (Note: training was stopped after `epoch-820.pt`)
```


@ -23,7 +23,11 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
import logging
from pathlib import Path
import tacotron_cleaner.cleaners
try:
import tacotron_cleaner.cleaners
except ModuleNotFoundError as ex:
raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n")
from lhotse import CutSet, load_manifest
from piper_phonemize import phonemize_espeak


@ -28,7 +28,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: build monotonic_align lib"
if [ ! -d vits/monotonic_align/build ]; then
cd vits/monotonic_align
python setup.py build_ext --inplace
python3 setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib already built"
@ -54,7 +54,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LJSpeech manifest"
# We assume that you have downloaded the LJSpeech corpus
# to $dl_dir/LJSpeech
# to $dl_dir/LJSpeech-1.1
mkdir -p data/manifests
if [ ! -e data/manifests/.ljspeech.done ]; then
lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
@ -82,8 +82,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for LJSpeech"
# We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
# - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
# - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html,
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
./local/prepare_tokens_ljspeech.py


@ -25,9 +25,8 @@ Export the model to ONNX:
--exp-dir vits/exp \
--tokens data/tokens.txt
It will generate two files inside vits/exp:
It will generate one file inside vits/exp:
- vits-epoch-1000.onnx
- vits-epoch-1000.int8.onnx (quantizated model)
See ./test_onnx.py for how to use the exported ONNX models.
"""
@ -40,7 +39,6 @@ from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params
@ -75,6 +73,16 @@ def get_parser():
help="""Path to vocabulary.""",
)
parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
@ -136,7 +144,7 @@ class OnnxModel(nn.Module):
Return a tuple containing:
- audio, generated wavform tensor, (B, T_wav)
"""
audio, _, _ = self.model.inference(
audio, _, _ = self.model.generator.inference(
text=tokens,
text_lengths=tokens_lens,
noise_scale=noise_scale,
@ -198,6 +206,11 @@ def export_model_onnx(
},
)
if model.model.spks is None:
num_speakers = 1
else:
num_speakers = model.model.spks
meta_data = {
"model_type": "vits",
"version": "1",
@ -206,8 +219,8 @@ def export_model_onnx(
"language": "English",
"voice": "en-us", # Choose your language appropriately
"has_espeak": 1,
"n_speakers": 1,
"sample_rate": 22050, # Must match the real sample rate
"n_speakers": num_speakers,
"sample_rate": model.model.sampling_rate, # Must match the real sample rate
}
logging.info(f"meta_data: {meta_data}")
@ -233,14 +246,13 @@ def main():
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model = model.generator
model.to("cpu")
model.eval()
model = OnnxModel(model=model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}")
logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")
suffix = f"epoch-{params.epoch}"
@ -256,18 +268,6 @@ def main():
)
logging.info(f"Exported generator to {model_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
weight_type=QuantType.QUInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"


@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module):
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
assert global_channels > 0
assert global_channels > 0, global_channels
self.spks = spks
self.global_emb = torch.nn.Embedding(spks, global_channels)
self.spk_embed_dim = None


@ -72,6 +72,16 @@ def get_parser():
help="""Path to vocabulary.""",
)
parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
@ -94,6 +104,7 @@ def infer_dataset(
tokenizer:
Used to convert text to phonemes.
"""
# Background worker that saves audios to disk.
def _save_worker(
batch_size: int,


@ -10,7 +10,11 @@ import warnings
import numpy as np
import torch
from numba import njit, prange
try:
from numba import njit, prange
except ModuleNotFoundError as ex:
raise RuntimeError(f"{ex}/nPlease run\n pip install numba")
try:
from .core import maximum_path_c


@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tokenizer import Tokenizer
from train import get_model, get_params
from vits import VITS
def test_model_type(model_type):
tokens = "./data/tokens.txt"
params = get_params()
tokenizer = Tokenizer(tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_type = model_type
model = get_model(params)
generator = model.generator
num_param = sum([p.numel() for p in generator.parameters()])
print(
f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
)
def main():
test_model_type("high") # 35.63 M
test_model_type("low") # 7.55 M
test_model_type("medium") # 23.61 M
if __name__ == "__main__":
main()

View File

@ -54,6 +54,20 @@ def get_parser():
help="""Path to vocabulary.""",
)
parser.add_argument(
"--text",
type=str,
default="Ask not what your country can do for you; ask what you can do for your country.",
help="Text to generate speech for",
)
parser.add_argument(
"--output-filename",
type=str,
default="test_onnx.wav",
help="Filename to save the generated wave file.",
)
return parser
@ -61,7 +75,7 @@ class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
@ -72,6 +86,9 @@ class OnnxModel:
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
metadata = self.model.get_modelmeta().custom_metadata_map
self.sample_rate = int(metadata["sample_rate"])
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
@ -101,13 +118,14 @@ class OnnxModel:
def main():
args = get_parser().parse_args()
logging.info(vars(args))
tokenizer = Tokenizer(args.tokens)
logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
text = args.text
tokens = tokenizer.texts_to_token_ids(
[text], intersperse_blank=True, add_sos=True, add_eos=True
)
@ -115,8 +133,9 @@ def main():
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
audio = model(tokens, tokens_lens) # (1, T')
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
logging.info("Saved to test_onnx.wav")
output_filename = args.output_filename
torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
logging.info(f"Saved to {output_filename}")
if __name__ == "__main__":

View File

@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module):
x_lengths (Tensor): Length tensor (B,).
Returns:
Tensor: Encoded hidden representation (B, attention_dim, T_text).
Tensor: Projected mean tensor (B, attention_dim, T_text).
Tensor: Projected scale tensor (B, attention_dim, T_text).
Tensor: Encoded hidden representation (B, embed_dim, T_text).
Tensor: Projected mean tensor (B, embed_dim, T_text).
Tensor: Projected scale tensor (B, embed_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module):
# encoder assume the channel last (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
# Note: attention_dim == embed_dim
# convert the channel first (B, embed_dim, T_text)
x = x.transpose(1, 2)

View File

@ -18,7 +18,15 @@ import logging
from typing import Dict, List
import tacotron_cleaner.cleaners
from piper_phonemize import phonemize_espeak
try:
from piper_phonemize import phonemize_espeak
except Exception as ex:
raise RuntimeError(
f"{ex}\nPlease run\n"
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
)
from utils import intersperse

View File

@ -153,6 +153,16 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
@ -189,15 +199,6 @@ def get_params() -> AttributeDict:
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warmup period that dictates the decay of the
scale on "simple" (un-pruned) loss.
"""
params = AttributeDict(
{
@ -278,6 +279,7 @@ def get_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
model_type=params.model_type,
mel_loss_params=mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
@ -363,7 +365,7 @@ def train_one_epoch(
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summary the stats over iterations in one epoch
# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False

View File

@ -255,6 +255,7 @@ class LJSpeechTtsDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
@ -294,6 +295,7 @@ class LJSpeechTtsDataModule:
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")

View File

@ -5,6 +5,7 @@
"""VITS module for GAN-TTS task."""
import copy
from typing import Any, Dict, Optional, Tuple
import torch
@ -38,6 +39,36 @@ AVAILABLE_DISCRIMINATORS = {
"hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
}
LOW_CONFIG = {
"hidden_channels": 96,
"decoder_upsample_scales": (8, 8, 4),
"decoder_channels": 256,
"decoder_upsample_kernel_sizes": (16, 16, 8),
"decoder_resblock_kernel_sizes": (3, 5, 7),
"decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
"text_encoder_cnn_module_kernel": 3,
}
MEDIUM_CONFIG = {
"hidden_channels": 192,
"decoder_upsample_scales": (8, 8, 4),
"decoder_channels": 256,
"decoder_upsample_kernel_sizes": (16, 16, 8),
"decoder_resblock_kernel_sizes": (3, 5, 7),
"decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
"text_encoder_cnn_module_kernel": 3,
}
HIGH_CONFIG = {
"hidden_channels": 192,
"decoder_upsample_scales": (8, 8, 2, 2),
"decoder_channels": 512,
"decoder_upsample_kernel_sizes": (16, 16, 4, 4),
"decoder_resblock_kernel_sizes": (3, 7, 11),
"decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
"text_encoder_cnn_module_kernel": 5,
}
class VITS(nn.Module):
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
@ -49,6 +80,7 @@ class VITS(nn.Module):
feature_dim: int = 513,
sampling_rate: int = 22050,
generator_type: str = "vits_generator",
model_type: str = "",
generator_params: Dict[str, Any] = {
"hidden_channels": 192,
"spks": None,
@ -155,12 +187,13 @@ class VITS(nn.Module):
"""Initialize VITS module.
Args:
idim (int): Input vocabrary size.
idim (int): Input vocabulary size.
odim (int): Acoustic feature dimension. The actual output channels will
be 1, since VITS is an end-to-end text-to-wave model, but odim is kept
for compatibility to indicate the acoustic feature dimension.
sampling_rate (int): Sampling rate; not used for training, but referred
to when saving waveforms during inference.
model_type (str): If not empty, must be one of: low, medium, high
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
@ -181,6 +214,24 @@ class VITS(nn.Module):
"""
super().__init__()
generator_params = copy.deepcopy(generator_params)
discriminator_params = copy.deepcopy(discriminator_params)
generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
mel_loss_params = copy.deepcopy(mel_loss_params)
if model_type != "":
assert model_type in ("low", "medium", "high"), model_type
if model_type == "low":
generator_params.update(LOW_CONFIG)
elif model_type == "medium":
generator_params.update(MEDIUM_CONFIG)
elif model_type == "high":
generator_params.update(HIGH_CONFIG)
else:
raise ValueError(f"Unknown model_type: ${model_type}")
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":

19
egs/mdcc/ASR/README.md Normal file
View File

@ -0,0 +1,19 @@
# Introduction
The Multi-Domain Cantonese Corpus (MDCC) consists of 73.6 hours of clean read speech paired with
transcripts, collected from Cantonese audiobooks from Hong Kong. It comprises philosophy,
politics, education, culture, lifestyle and family domains, covering a wide range of topics.
The manuscript can be found at: https://arxiv.org/abs/2201.02419
# Transducers
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|-----------------------------|
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
The decoder is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
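For illustration, here is a minimal sketch of such a stateless decoder; the class
name and dimensions are ours (illustrative), not the exact recipe code:

```
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding + Conv1d decoder; a sketch, not the exact recipe code."""

    def __init__(self, vocab_size: int, decoder_dim: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, decoder_dim)
        # With context-size set to 1, the Conv1d has kernel size 1, i.e. the
        # decoder conditions on just the previously emitted token.
        self.conv = nn.Conv1d(decoder_dim, decoder_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        emb = self.embedding(y).transpose(1, 2)  # (N, U) -> (N, decoder_dim, U)
        out = self.conv(emb)                     # (N, decoder_dim, U) for kernel size 1
        return out.transpose(1, 2)               # (N, U, decoder_dim)


decoder = StatelessDecoder(vocab_size=4336, decoder_dim=512, context_size=1)
assert decoder(torch.zeros(2, 5, dtype=torch.long)).shape == (2, 5, 512)
```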

41
egs/mdcc/ASR/RESULTS.md Normal file
View File

@ -0,0 +1,41 @@
## Results
#### Zipformer
See <https://github.com/k2-fsa/icefall/pull/1537>
[./zipformer](./zipformer)
##### normal-scaled model, number of model parameters: 74470867, i.e., 74.47 M
| | test | valid | comment |
|------------------------|------|-------|-----------------------------------------|
| greedy search | 7.45 | 7.51 | --epoch 45 --avg 35 |
| modified beam search | 6.68 | 6.73 | --epoch 45 --avg 35 |
| fast beam search | 7.22 | 7.28 | --epoch 45 --avg 35 |
The training command:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--start-epoch 1 \
--num-epochs 50 \
--use-fp16 1 \
--exp-dir ./zipformer/exp \
--max-duration 1000
```
The decoding command:
```
./zipformer/decode.py \
--epoch 45 \
--avg 35 \
--exp-dir ./zipformer/exp \
--decoding-method greedy_search  # or: modified_beam_search, fast_beam_search
```
The pretrained model is available at: https://huggingface.co/zrjin/icefall-asr-mdcc-zipformer-2024-03-11/

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_hlg.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_hlg_using_openfst.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py

View File

@ -0,0 +1,157 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the MDCC dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_mdcc(
num_mel_bins: int = 80,
perturb_speed: bool = False,
whisper_fbank: bool = False,
output_dir: str = "data/fbank",
):
src_dir = Path("data/manifests")
output_dir = Path(output_dir)
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"train",
"valid",
"test",
)
prefix = "mdcc"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
parser.add_argument(
"--output-dir",
type=str,
default="data/fbank",
help="Output directory. Default: data/fbank.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_mdcc(
num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
output_dir=args.output_dir,
)

View File

@ -0,0 +1,144 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()` in transducer/train.py
for usage.
"""
from lhotse import load_manifest_lazy
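# For illustration only (not in the original script): a sketch of the kind of
# filter the statistics below help to configure. The 1 s / 15 s bounds are
# assumptions read off the 99.9% duration percentile shown below; see
# remove_short_and_long_utt() in the training script for the real filter.
def example_remove_short_and_long_utt(c) -> bool:
    return 1.0 <= c.duration <= 15.0
# Typical usage: cuts = cuts.filter(example_remove_short_and_long_utt)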
def main():
path = "./data/fbank/mdcc_cuts_train.jsonl.gz"
path = "./data/fbank/mdcc_cuts_valid.jsonl.gz"
path = "./data/fbank/mdcc_cuts_test.jsonl.gz"
cuts = load_manifest_lazy(path)
cuts.describe(full=True)
if __name__ == "__main__":
main()
"""
data/fbank/mdcc_cuts_train.jsonl.gz (with speed perturbation)
Cuts count:                  195360
Total duration (hh:mm:ss):   173:44:59
mean:                        3.2
std:                         2.1
min:                         0.2
25%:                         1.8
50%:                         2.7
75%:                         4.0
99%:                         11.0
99.5%:                       12.4
99.9%:                       14.8
max:                         16.7
Recordings available:        195360
Features available:          195360
Supervisions available:      195360

data/fbank/mdcc_cuts_valid.jsonl.gz
Cuts count:                  5663
Total duration (hh:mm:ss):   05:03:12
mean:                        3.2
std:                         2.0
min:                         0.3
25%:                         1.8
50%:                         2.7
75%:                         4.0
99%:                         10.9
99.5%:                       12.3
99.9%:                       14.4
max:                         14.8
Recordings available:        5663
Features available:          5663
Supervisions available:      5663

data/fbank/mdcc_cuts_test.jsonl.gz
Cuts count:                  12492
Total duration (hh:mm:ss):   11:00:31
mean:                        3.2
std:                         2.0
min:                         0.2
25%:                         1.8
50%:                         2.7
75%:                         4.0
99%:                         10.5
99.5%:                       12.1
99.9%:                       14.0
max:                         14.8
Recordings available:        12492
Features available:          12492
Supervisions available:      12492
"""

View File

@ -0,0 +1 @@
../../../aishell/ASR/local/prepare_char.py

View File

@ -0,0 +1 @@
../../../aishell/ASR/local/prepare_char_lm_training_data.py

View File

@ -0,0 +1 @@
../../../aishell/ASR/local/prepare_lang.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_fst.py

View File

@ -0,0 +1,157 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes a text file "data/lang_char/text" as input; the file consists of
lines, each containing a transcript. It applies text normalization and generates the
following files in the directory "data/lang_char":
- text_norm
- words.txt
- words_no_ids.txt
- text_words_segmentation
"""
import argparse
import logging
from pathlib import Path
from typing import List
import pycantonese
from tqdm.auto import tqdm
from icefall.utils import is_cjk
def get_parser():
parser = argparse.ArgumentParser(
description="Prepare char lexicon",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--input-file",
"-i",
default="data/lang_char/text",
type=str,
help="The input text file",
)
parser.add_argument(
"--output-dir",
"-o",
default="data/lang_char",
type=str,
help="The output directory",
)
return parser
def get_norm_lines(lines: List[str]) -> List[str]:
def _text_norm(text: str) -> str:
# to cope with the protocol for transcription:
# When taking notes, the annotators adhere to the following guidelines:
# 1) If the audio contains pure music, the annotators mark the label
# "(music)" in the file name of its transcript. 2) If the utterance
# contains one or several sentences with background music or noise, the
# annotators mark the label "(music)" before each sentence in the transcript.
# 3) The annotators use {} symbols to enclose words they are uncertain
# about, for example, {梁佳佳},我是{}人.
# here we manually fix some errors in the transcript
return (
text.strip()
.replace("(music)", "")
.replace("(music", "")
.replace("{", "")
.replace("}", "")
.replace("BB所以就指腹為親喇", "BB 所以就指腹為親喇")
.upper()
)
return [_text_norm(line) for line in lines]
def get_word_segments(lines: List[str]) -> List[str]:
# the current pycantonese segmenter does not handle the case when the input
# is code switching, so we need to handle it separately
new_lines = []
for line in tqdm(lines, desc="Segmenting lines"):
try:
# code switching
if len(line.strip().split(" ")) > 1:
segments = []
for segment in line.strip().split(" "):
if segment.strip() == "":
continue
try:
if not is_cjk(segment[0]): # en segment
segments.append(segment)
else: # zh segment
segments.extend(pycantonese.segment(segment))
except Exception as e:
logging.error(f"Failed to process segment: {segment}")
raise e
new_lines.append(" ".join(segments) + "\n")
# not code switching
else:
new_lines.append(" ".join(pycantonese.segment(line)) + "\n")
except Exception as e:
logging.error(f"Failed to process line: {line}")
raise e
return new_lines
def get_words(lines: List[str]) -> List[str]:
words = set()
for line in tqdm(lines, desc="Getting words"):
words.update(line.strip().split(" "))
return list(words)
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
input_file = Path(args.input_file)
output_dir = Path(args.output_dir)
assert output_dir.is_dir(), f"{output_dir} does not exist"
assert input_file.is_file(), f"{input_file} does not exist"
lines = input_file.read_text(encoding="utf-8").strip().split("\n")
norm_lines = get_norm_lines(lines)
with open(output_dir / "text_norm", "w+", encoding="utf-8") as f:
f.writelines([line + "\n" for line in norm_lines])
text_words_segments = get_word_segments(norm_lines)
with open(output_dir / "text_words_segmentation", "w+", encoding="utf-8") as f:
f.writelines(text_words_segments)
words = get_words(text_words_segments)[1:] # remove "\n" from words
with open(output_dir / "words_no_ids.txt", "w+", encoding="utf-8") as f:
f.writelines([word + "\n" for word in sorted(words)])
words = (
["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>"]
+ sorted(words)
+ ["#0", "<s>", "<\s>"]
)
with open(output_dir / "words.txt", "w+", encoding="utf-8") as f:
f.writelines([f"{word} {i}\n" for i, word in enumerate(words)])

View File

@ -0,0 +1,86 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
# 2022 Xiaomi Corp. (authors: Weiji Zhuang)
# 2024 Xiaomi Corp. (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input "text", the transcript file for
MDCC:
- text
and generates the word-segmented output file:
- text_words_segmentation
"""
import argparse
from typing import List
import pycantonese
from tqdm.auto import tqdm
def get_parser():
parser = argparse.ArgumentParser(
description="Cantonese Word Segmentation for text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--input-file",
"-i",
default="data/lang_char/text",
type=str,
help="the input text file for MDCC",
)
parser.add_argument(
"--output-file",
"-o",
default="data/lang_char/text_words_segmentation",
type=str,
help="the text implemented with words segmenting for MDCC",
)
return parser
def get_word_segments(lines: List[str]) -> List[str]:
return [
" ".join(pycantonese.segment(line)) + "\n"
for line in tqdm(lines, desc="Segmenting lines")
]
def main():
parser = get_parser()
args = parser.parse_args()
input_file = args.input_file
output_file = args.output_file
with open(input_file, "r", encoding="utf-8") as fr:
lines = fr.readlines()
new_lines = get_word_segments(lines)
with open(output_file, "w", encoding="utf-8") as fw:
fw.writelines(new_lines)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../aidatatang_200zh/ASR/local/text2token.py

308
egs/mdcc/ASR/prepare.sh Executable file
View File

@ -0,0 +1,308 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=-1
stop_stage=100
perturb_speed=true
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/mdcc
# |-- README.md
# |-- audio/
# |-- clip_info_rthk.csv
# |-- cnt_asr_metadata_full.csv
# |-- cnt_asr_test_metadata.csv
# |-- cnt_asr_train_metadata.csv
# |-- cnt_asr_valid_metadata.csv
# |-- data_statistic.py
# |-- length
# |-- podcast_447_2021.csv
# |-- test.txt
# |-- transcription/
# `-- words_length
# You can download them from:
# https://drive.google.com/file/d/1epfYMMhXdBKA6nxPgUugb2Uj4DllSxkn/view?usp=drive_link
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download data"
# If you have pre-downloaded it to /path/to/mdcc,
# you can create a symlink
#
# ln -sfv /path/to/mdcc $dl_dir/mdcc
#
# The directory structure is
# mdcc/
# |-- README.md
# |-- audio/
# |-- clip_info_rthk.csv
# |-- cnt_asr_metadata_full.csv
# |-- cnt_asr_test_metadata.csv
# |-- cnt_asr_train_metadata.csv
# |-- cnt_asr_valid_metadata.csv
# |-- data_statistic.py
# |-- length
# |-- podcast_447_2021.csv
# |-- test.txt
# |-- transcription/
# `-- words_length
if [ ! -d $dl_dir/mdcc/audio ]; then
lhotse download mdcc $dl_dir
# this will download and unzip dataset.zip to $dl_dir/
mv $dl_dir/dataset $dl_dir/mdcc
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/musan
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare MDCC manifest"
# We assume that you have downloaded the MDCC corpus
# to $dl_dir/mdcc
if [ ! -f data/manifests/.mdcc_manifests.done ]; then
log "Might take 40 minutes to traverse the directory."
mkdir -p data/manifests
lhotse prepare mdcc $dl_dir/mdcc data/manifests
touch data/manifests/.mdcc_manifests.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
if [ ! -f data/manifests/.musan_manifests.done ]; then
log "It may take 6 minutes"
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan_manifests.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for MDCC"
if [ ! -f data/fbank/.mdcc.done ]; then
mkdir -p data/fbank
./local/compute_fbank_mdcc.py --perturb-speed ${perturb_speed}
touch data/fbank/.mdcc.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.musan.done ]; then
mkdir -p data/fbank
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi
lang_char_dir=data/lang_char
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare char based lang"
mkdir -p $lang_char_dir
# Prepare text.
# Note: in Linux, you can install jq with the following command:
# 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
# 2. chmod +x ./jq
# 3. cp jq /usr/bin
if [ ! -f $lang_char_dir/text ]; then
gunzip -c data/manifests/mdcc_supervisions_train.jsonl.gz \
|jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \
> $lang_char_dir/train_text
cat $lang_char_dir/train_text > $lang_char_dir/text
gunzip -c data/manifests/mdcc_supervisions_valid.jsonl.gz \
|jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \
> $lang_char_dir/valid_text
cat $lang_char_dir/valid_text >> $lang_char_dir/text
gunzip -c data/manifests/mdcc_supervisions_test.jsonl.gz \
|jq '.text' | sed 's/"//g' | ./local/text2token.py -t "char" \
> $lang_char_dir/test_text
cat $lang_char_dir/test_text >> $lang_char_dir/text
fi
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
./local/preprocess_mdcc.py --input-file $lang_char_dir/text \
--output-dir $lang_char_dir
mv $lang_char_dir/text $lang_char_dir/_text
cp $lang_char_dir/text_words_segmentation $lang_char_dir/text
fi
if [ ! -f $lang_char_dir/tokens.txt ]; then
./local/prepare_char.py --lang-dir $lang_char_dir
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare G"
mkdir -p data/lm
# Train LM on transcripts
if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
python3 ./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lang_char_dir/text_words_segmentation \
-lm data/lm/3-gram.unpruned.arpa
fi
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="$lang_char_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
fi
if [ ! -f $lang_char_dir/HLG.fst ]; then
./local/prepare_lang_fst.py \
--lang-dir $lang_char_dir \
--ngram-G ./data/lm/G_3_gram_char.fst.txt
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Compile LG & HLG"
./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Generate LM training data"
log "Processing char based data"
out_dir=data/lm_training_char
mkdir -p $out_dir $dl_dir/lm
if [ ! -f $dl_dir/lm/mdcc-train-word.txt ]; then
./local/text2segments.py --input-file $lang_char_dir/train_text \
--output-file $dl_dir/lm/mdcc-train-word.txt
fi
# training words
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $dl_dir/lm/mdcc-train-word.txt \
--lm-archive $out_dir/lm_data.pt
# valid words
if [ ! -f $dl_dir/lm/mdcc-valid-word.txt ]; then
./local/text2segments.py --input-file $lang_char_dir/valid_text \
--output-file $dl_dir/lm/mdcc-valid-word.txt
fi
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $dl_dir/lm/mdcc-valid-word.txt \
--lm-archive $out_dir/lm_data_valid.pt
# test words
if [ ! -f $dl_dir/lm/mdcc-test-word.txt ]; then
./local/text2segments.py --input-file $lang_char_dir/test_text \
--output-file $dl_dir/lm/mdcc-test-word.txt
fi
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $dl_dir/lm/mdcc-test-word.txt \
--lm-archive $out_dir/lm_data_test.pt
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Sort LM training data"
# Sort LM training data by sentence length in descending order
# for ease of training.
#
# Sentence length equals the number of tokens
# in a sentence.
out_dir=data/lm_training_char
mkdir -p $out_dir
ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data.pt \
--out-lm-data $out_dir/sorted_lm_data.pt \
--out-statistics $out_dir/statistics.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data_valid.pt \
--out-lm-data $out_dir/sorted_lm_data-valid.pt \
--out-statistics $out_dir/statistics-valid.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data_test.pt \
--out-lm-data $out_dir/sorted_lm_data-test.pt \
--out-statistics $out_dir/statistics-test.txt
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \
--world-size 1 \
--num-epochs 20 \
--use-fp16 0 \
--embedding-dim 512 \
--hidden-dim 512 \
--num-layers 2 \
--batch-size 400 \
--exp-dir rnnlm_char/exp \
--lm-data $out_dir/sorted_lm_data.pt \
--lm-data-valid $out_dir/sorted_lm_data-valid.pt \
--vocab-size 4336 \
--master-port 12345
fi

1
egs/mdcc/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

View File

@ -0,0 +1,382 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# Copyright 2024 Xiaomi Corporation (Author: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class MdccAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures()
),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest_lazy(
self.args.manifest_dir / "mdcc_cuts_train.jsonl.gz"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get valid cuts")
return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_valid.jsonl.gz")
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "mdcc_cuts_test.jsonl.gz")

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

813
egs/mdcc/ASR/zipformer/decode.py Executable file
View File

@ -0,0 +1,813 @@
#!/usr/bin/env python3
#
# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Mingshuang Luo,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method greedy_search
(2) modified beam search
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(3) fast beam search (trivial_graph)
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(4) fast beam search (LG)
./zipformer/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest oracle WER)
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import MdccAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest_oracle
If you use fast_beam_search_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search, fast_beam_search_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--ilme-scale",
type=float,
default=0.2,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for the internal language model estimation.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_LG":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp])
hyps.append(list(sentence))
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
blank_penalty=params.blank_penalty,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search_" + key: hyps}
elif "fast_beam_search" in params.decoding_method:
key += f"_beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ilme_scale_{params.ilme_scale}"
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}_" + key: hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
      lexicon:
        It contains the token symbol table.
      graph_compiler:
        It is used to convert reference transcripts to token IDs
        (needed only by fast_beam_search_nbest_oracle).
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
        only when --decoding-method is fast_beam_search, fast_beam_search_LG,
        or fast_beam_search_nbest_oracle.
    Returns:
      Return a dict whose key encodes the decoding settings, e.g.
      "greedy_search_blank_penalty_0.0". Its value is a list of tuples.
      Each tuple contains three elements: the cut ID, the reference
      transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
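

# For illustration (values assumed), decode_dataset returns e.g.:
#     {
#         "greedy_search_blank_penalty_0.0": [
#             ("cut-001", ["今", "日"], ["今", "日"]),
#             ("cut-002", ["天", "氣"], ["天", "機"]),
#         ]
#     }
# i.e. one (cut_id, reference, hypothesis) tuple per utterance.
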
def save_results(
params: AttributeDict,
test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
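

# With two assumed settings, the summary logged above for the "test" set
# would read (sorted by WER, the best setting flagged):
#
#     For test, WER of different settings are:
#     greedy_search_blank_penalty_0.0	10.21	best for test
#     beam_size_4_blank_penalty_0.0	10.53
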
@torch.no_grad()
def main():
parser = get_parser()
MdccAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest_oracle",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"_ilme_scale_{params.ilme_scale}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
logging.info(params)
logging.info("About to create model")
model = get_model(params)
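    # Two checkpoint-averaging modes follow: without --use-averaged-model,
    # the state_dicts of the selected checkpoints are averaged directly;
    # with it, the running model averages stored in two checkpoints are
    # combined to cover exactly the requested range (hence "(excluded)"
    # in the log messages below).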
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
mdcc = MdccAsrDataModule(args)
def remove_short_utt(c: Cut):
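        # Estimate the number of encoder input frames after the front-end's
        # roughly 4x convolutional subsampling (assumed: (T - 7) // 2 from
        # the conv stack, then a further halving); drop cuts that would
        # yield no frames.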
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
)
return T > 0
valid_cuts = mdcc.valid_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = mdcc.valid_dataloaders(valid_cuts)
test_cuts = mdcc.test_cuts()
test_cuts = test_cuts.filter(remove_short_utt)
test_dl = mdcc.test_dataloaders(test_cuts)
test_sets = ["valid", "test"]
test_dls = [valid_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()
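

# A minimal sketch (not part of the recipe; names assumed) of what
# --blank-penalty does inside the search functions used above: the blank
# logit is reduced before normalization, which discourages blank emissions
# and thereby reduces deletion errors at the risk of extra insertions.
def _apply_blank_penalty_sketch(
    logits: torch.Tensor, blank_id: int, penalty: float
) -> torch.Tensor:
    penalized = logits.clone()
    penalized[..., blank_id] -= penalty  # subtract the penalty from the blank logit only
    return penalized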

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decode_stream.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_check.py

View File

@ -0,0 +1,286 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import MdccAsrDataModule
from lhotse.cut import Cut
from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser
def decode_one_batch(
model: OnnxModel, token_table: k2.SymbolTable, batch: dict
) -> List[List[str]]:
"""Decode one batch and return the result.
    Currently only greedy_search is supported.
Args:
model:
The neural model.
token_table:
Mapping ids to tokens.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoded results for each utterance.
"""
feature = batch["inputs"]
assert feature.ndim == 3
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
hyps = greedy_search(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
)
hyps = [[token_table[h] for h in hyp] for hyp in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
token_table: k2.SymbolTable,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
model:
The neural model.
token_table:
Mapping ids to tokens.
Returns:
- A list of tuples. Each tuple contains three elements:
- cut_id,
- reference transcript,
- predicted result.
- The total duration (in seconds) of the dataset.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 10
total_duration = 0
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = list(ref_text)
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, total_duration
def save_results(
res_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
):
recog_path = res_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = res_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("WER", file=f)
print(wer, file=f)
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
MdccAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
args.decoding_method == "greedy_search"
), "Only supports greedy_search currently."
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
setup_logger(f"{res_dir}/log-decode")
logging.info("Decoding started")
device = torch.device("cpu")
logging.info(f"Device: {device}")
token_table = k2.SymbolTable.from_file(args.tokens)
assert token_table[0] == "<blk>"
logging.info(vars(args))
logging.info("About to create model")
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
# we need cut ids to display recognition results.
args.return_cuts = True
mdcc = MdccAsrDataModule(args)
def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
)
return T > 0
valid_cuts = mdcc.valid_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = mdcc.valid_dataloaders(valid_cuts)
    test_cuts = mdcc.test_cuts()
test_cuts = test_cuts.filter(remove_short_utt)
test_dl = mdcc.test_dataloaders(test_cuts)
test_sets = ["valid", "test"]
    test_dls = [valid_dl, test_dl]
    for test_set, test_dl in zip(test_sets, test_dls):
start_time = time.time()
results, total_duration = decode_dataset(
dl=test_dl, model=model, token_table=token_table
)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
logging.info(f"Wave duration: {total_duration:.3f} s")
logging.info(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()
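
# Worked example for the RTF log above (values assumed): decoding 360 s of
# audio in 45 s gives RTF = 45 / 360 = 0.125, i.e. roughly 8x faster than
# real time.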

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

Some files were not shown because too many files have changed in this diff