mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Merge branch 'k2-fsa:master' into dev/tts/vctk/tokenizer
This commit is contained in:
commit
7ea100a26a
1
.github/scripts/.gitignore
vendored
Normal file
1
.github/scripts/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
piper_phonemize.html
|
29
.github/scripts/generate-piper-phonemize-page.py
vendored
Executable file
29
.github/scripts/generate-piper-phonemize-page.py
vendored
Executable file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
def main():
|
||||
prefix = (
|
||||
"https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
|
||||
)
|
||||
files = [
|
||||
"piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
]
|
||||
with open("piper_phonemize.html", "w") as f:
|
||||
for file in files:
|
||||
url = prefix + file
|
||||
f.write(f'<a href="{url}">{file}</a><br/>\n')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
6
.github/scripts/librispeech/ASR/run.sh
vendored
6
.github/scripts/librispeech/ASR/run.sh
vendored
@ -15,9 +15,9 @@ function prepare_data() {
|
||||
# cause OOM error for CI later.
|
||||
mkdir -p download/lm
|
||||
pushd download/lm
|
||||
wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
|
||||
wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
|
||||
wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
|
||||
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
|
||||
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
|
||||
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt
|
||||
ls -lh
|
||||
gunzip librispeech-lm-norm.txt.gz
|
||||
|
||||
|
157
.github/scripts/ljspeech/TTS/run.sh
vendored
Executable file
157
.github/scripts/ljspeech/TTS/run.sh
vendored
Executable file
@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -ex
|
||||
|
||||
python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
|
||||
python3 -m pip install espnet_tts_frontend
|
||||
python3 -m pip install numba
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/ljspeech/TTS
|
||||
|
||||
sed -i.bak s/600/8/g ./prepare.sh
|
||||
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
|
||||
sed -i.bak s/500/5/g ./prepare.sh
|
||||
git diff
|
||||
|
||||
function prepare_data() {
|
||||
# We have created a subset of the data for testing
|
||||
#
|
||||
mkdir download
|
||||
pushd download
|
||||
wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
|
||||
tar xvf LJSpeech-1.1.tar.bz2
|
||||
popd
|
||||
|
||||
./prepare.sh
|
||||
tree .
|
||||
}
|
||||
|
||||
function train() {
|
||||
pushd ./vits
|
||||
sed -i.bak s/200/3/g ./train.py
|
||||
git diff .
|
||||
popd
|
||||
|
||||
for t in low medium high; do
|
||||
./vits/train.py \
|
||||
--exp-dir vits/exp-$t \
|
||||
--model-type $t \
|
||||
--num-epochs 1 \
|
||||
--save-every-n 1 \
|
||||
--num-buckets 2 \
|
||||
--tokens data/tokens.txt \
|
||||
--max-duration 20
|
||||
|
||||
ls -lh vits/exp-$t
|
||||
done
|
||||
}
|
||||
|
||||
function infer() {
|
||||
for t in low medium high; do
|
||||
./vits/infer.py \
|
||||
--num-buckets 2 \
|
||||
--model-type $t \
|
||||
--epoch 1 \
|
||||
--exp-dir ./vits/exp-$t \
|
||||
--tokens data/tokens.txt \
|
||||
--max-duration 20
|
||||
done
|
||||
}
|
||||
|
||||
function export_onnx() {
|
||||
for t in low medium high; do
|
||||
./vits/export-onnx.py \
|
||||
--model-type $t \
|
||||
--epoch 1 \
|
||||
--exp-dir ./vits/exp-$t \
|
||||
--tokens data/tokens.txt
|
||||
|
||||
ls -lh vits/exp-$t/
|
||||
done
|
||||
}
|
||||
|
||||
function test_medium() {
|
||||
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12
|
||||
|
||||
./vits/export-onnx.py \
|
||||
--model-type medium \
|
||||
--epoch 820 \
|
||||
--exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
|
||||
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt
|
||||
|
||||
ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp
|
||||
|
||||
./vits/test_onnx.py \
|
||||
--model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
|
||||
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
|
||||
--output-filename /icefall/test-medium.wav
|
||||
|
||||
ls -lh /icefall/test-medium.wav
|
||||
|
||||
d=/icefall/vits-icefall-en_US-ljspeech-medium
|
||||
mkdir $d
|
||||
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
|
||||
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx
|
||||
|
||||
rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12
|
||||
|
||||
pushd $d
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
|
||||
tar xf espeak-ng-data.tar.bz2
|
||||
rm espeak-ng-data.tar.bz2
|
||||
cd ..
|
||||
tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
|
||||
rm -rf vits-icefall-en_US-ljspeech-medium
|
||||
ls -lh *.tar.bz2
|
||||
popd
|
||||
}
|
||||
|
||||
function test_low() {
|
||||
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12
|
||||
|
||||
./vits/export-onnx.py \
|
||||
--model-type low \
|
||||
--epoch 1600 \
|
||||
--exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
|
||||
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt
|
||||
|
||||
ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp
|
||||
|
||||
./vits/test_onnx.py \
|
||||
--model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
|
||||
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
|
||||
--output-filename /icefall/test-low.wav
|
||||
|
||||
ls -lh /icefall/test-low.wav
|
||||
|
||||
d=/icefall/vits-icefall-en_US-ljspeech-low
|
||||
mkdir $d
|
||||
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
|
||||
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx
|
||||
|
||||
rm -rf icefall-tts-ljspeech-vits-low-2024-03-12
|
||||
|
||||
pushd $d
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
|
||||
tar xf espeak-ng-data.tar.bz2
|
||||
rm espeak-ng-data.tar.bz2
|
||||
cd ..
|
||||
tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
|
||||
rm -rf vits-icefall-en_US-ljspeech-low
|
||||
ls -lh *.tar.bz2
|
||||
popd
|
||||
}
|
||||
|
||||
prepare_data
|
||||
train
|
||||
infer
|
||||
export_onnx
|
||||
rm -rf vits/exp-{low,medium,high}
|
||||
test_medium
|
||||
test_low
|
3
.github/workflows/build-doc.yml
vendored
3
.github/workflows/build-doc.yml
vendored
@ -56,11 +56,14 @@ jobs:
|
||||
- name: Build doc
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/generate-piper-phonemize-page.py
|
||||
cd docs
|
||||
python3 -m pip install -r ./requirements.txt
|
||||
make html
|
||||
touch build/html/.nojekyll
|
||||
|
||||
cp -v ../piper_phonemize.html ./build/html/
|
||||
|
||||
- name: Deploy
|
||||
uses: peaceiris/actions-gh-pages@v3
|
||||
with:
|
||||
|
102
.github/workflows/ljspeech.yml
vendored
Normal file
102
.github/workflows/ljspeech.yml
vendored
Normal file
@ -0,0 +1,102 @@
|
||||
name: ljspeech
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ljspeech-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
generate_build_matrix:
|
||||
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
|
||||
# see https://github.com/pytorch/pytorch/pull/50633
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Generating build matrix
|
||||
id: set-matrix
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
|
||||
ljspeech:
|
||||
needs: generate_build_matrix
|
||||
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Free space
|
||||
shell: bash
|
||||
run: |
|
||||
ls -lh
|
||||
df -h
|
||||
rm -rf /opt/hostedtoolcache
|
||||
df -h
|
||||
echo "pwd: $PWD"
|
||||
echo "github.workspace ${{ github.workspace }}"
|
||||
|
||||
- name: Run tests
|
||||
uses: addnab/docker-run-action@v3
|
||||
with:
|
||||
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
|
||||
options: |
|
||||
--volume ${{ github.workspace }}/:/icefall
|
||||
shell: bash
|
||||
run: |
|
||||
export PYTHONPATH=/icefall:$PYTHONPATH
|
||||
cd /icefall
|
||||
git config --global --add safe.directory /icefall
|
||||
|
||||
.github/scripts/ljspeech/TTS/run.sh
|
||||
|
||||
- name: display files
|
||||
shell: bash
|
||||
run: |
|
||||
ls -lh
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
|
||||
with:
|
||||
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
|
||||
path: ./*.wav
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
|
||||
with:
|
||||
name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
|
||||
path: ./*.wav
|
||||
|
||||
- name: Release exported onnx models
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
overwrite: true
|
||||
file: vits-icefall-*.tar.bz2
|
||||
repo_name: k2-fsa/sherpa-onnx
|
||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||
tag: tts-models
|
||||
|
225
docs/source/recipes/Finetune/adapter/finetune_adapter.rst
Normal file
225
docs/source/recipes/Finetune/adapter/finetune_adapter.rst
Normal file
@ -0,0 +1,225 @@
|
||||
Finetune from a pre-trained Zipformer model with adapters
|
||||
=========================================================
|
||||
|
||||
This tutorial shows you how to fine-tune a pre-trained **Zipformer**
|
||||
transducer model on a new dataset with adapters.
|
||||
Adapters are compact and efficient module that can be integrated into a pre-trained model
|
||||
to improve the model's performance on a new domain. Adapters are injected
|
||||
between different modules in the well-trained neural network. During training, only the parameters
|
||||
in the adapters will be updated. It achieves competitive performance
|
||||
while requiring much less GPU memory than full fine-tuning. For more details about adapters,
|
||||
please refer to the original `paper <https://arxiv.org/pdf/1902.00751.pdf#/>`_ for more details.
|
||||
|
||||
.. HINT::
|
||||
|
||||
We assume you have read the page :ref:`install icefall` and have setup
|
||||
the environment for ``icefall``.
|
||||
|
||||
.. HINT::
|
||||
|
||||
We recommend you to use a GPU or several GPUs to run this recipe
|
||||
|
||||
For illustration purpose, we fine-tune the Zipformer transducer model
|
||||
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
|
||||
own data for fine-tuning if you create a manifest for your new dataset.
|
||||
|
||||
Data preparation
|
||||
----------------
|
||||
|
||||
Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
|
||||
to prepare the fine-tune data used in this tutorial. We only require the small subset in GigaSpeech for this tutorial.
|
||||
|
||||
|
||||
Model preparation
|
||||
-----------------
|
||||
|
||||
We are using the Zipformer model trained on full LibriSpeech (960 hours) as the intialization. The
|
||||
checkpoint of the model can be downloaded via the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
$ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt
|
||||
$ cd ../data/lang_bpe_500
|
||||
$ git lfs pull --include bpe.model
|
||||
$ cd ../../..
|
||||
|
||||
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
|
||||
decoding on the GigaSpeech test sets:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./zipformer/decode_gigaspeech.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
|
||||
--use-averaged-model 0 \
|
||||
--max-duration 1000 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
You should see the following numbers:
|
||||
|
||||
.. code-block::
|
||||
|
||||
For dev, WER of different settings are:
|
||||
greedy_search 20.06 best for dev
|
||||
|
||||
For test, WER of different settings are:
|
||||
greedy_search 19.27 best for test
|
||||
|
||||
|
||||
Fine-tune with adapter
|
||||
----------------------
|
||||
|
||||
We insert 4 adapters with residual connection in each ``Zipformer2EncoderLayer``.
|
||||
The original model parameters remain untouched during training and only the parameters of
|
||||
the adapters are updated. The following command starts a fine-tuning experiment with adapters:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ do_finetune=1
|
||||
$ use_adapters=1
|
||||
$ adapter_dim=8
|
||||
|
||||
$ ./zipformer_adapter/train.py \
|
||||
--world-size 2 \
|
||||
--num-epochs 20 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||
--use-fp16 1 \
|
||||
--base-lr 0.045 \
|
||||
--use-adapters $use_adapters --adapter-dim $adapter_dim \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--do-finetune $do_finetune \
|
||||
--master-port 13022 \
|
||||
--finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
|
||||
--max-duration 1000
|
||||
|
||||
The following arguments are related to fine-tuning:
|
||||
|
||||
- ``--do-finetune``
|
||||
If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
|
||||
**Note that if you want to resume your fine-tuning experiment from certain epochs, you
|
||||
need to set this to False.**
|
||||
|
||||
- ``use-adapters``
|
||||
If adapters are used during fine-tuning.
|
||||
|
||||
- ``--adapter-dim``
|
||||
The bottleneck dimension of the adapter module. Typically a small number.
|
||||
|
||||
You should notice that in the training log, the total number of trainale parameters is shown:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)
|
||||
|
||||
The trainable parameters only makes up 1.15% of the entire model parameters, so the training will be much faster
|
||||
and requires less memory than full fine-tuning.
|
||||
|
||||
|
||||
Decoding
|
||||
--------
|
||||
|
||||
After training, let's test the WERs. To test the WERs on the GigaSpeech set,
|
||||
you can execute the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ epoch=20
|
||||
$ avg=10
|
||||
$ use_adapters=1
|
||||
$ adapter_dim=8
|
||||
|
||||
% ./zipformer/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--use-averaged-model 1 \
|
||||
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||
--max-duration 600 \
|
||||
--use-adapters $use_adapters \
|
||||
--adapter-dim $adapter_dim \
|
||||
--decoding-method greedy_search
|
||||
|
||||
You should see the following numbers:
|
||||
|
||||
.. code-block::
|
||||
|
||||
For dev, WER of different settings are:
|
||||
greedy_search 15.44 best for dev
|
||||
|
||||
For test, WER of different settings are:
|
||||
greedy_search 15.42 best for test
|
||||
|
||||
|
||||
The WER on test set is improved from 19.27 to 15.42, demonstrating the effectiveness of adapters.
|
||||
|
||||
The same model can be used to perform decoding on LibriSpeech test sets. You can deactivate the adapters
|
||||
to keep the same performance of the original model:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ epoch=20
|
||||
$ avg=1
|
||||
$ use_adapters=0
|
||||
$ adapter_dim=8
|
||||
|
||||
% ./zipformer/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--use-averaged-model 1 \
|
||||
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||
--max-duration 600 \
|
||||
--use-adapters $use_adapters \
|
||||
--adapter-dim $adapter_dim \
|
||||
--decoding-method greedy_search
|
||||
|
||||
|
||||
.. code-block::
|
||||
|
||||
For dev, WER of different settings are:
|
||||
greedy_search 2.23 best for test-clean
|
||||
|
||||
For test, WER of different settings are:
|
||||
greedy_search 4.96 best for test-other
|
||||
|
||||
The numbers are the same as reported in `icefall <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md#normal-scaled-model-number-of-model-parameters-65549011-ie-6555-m>`_. So adapter-based
|
||||
fine-tuning is also very flexible as the same model can be used for decoding on the original and target domain.
|
||||
|
||||
|
||||
Export the model
|
||||
----------------
|
||||
|
||||
After training, the model can be exported to ``onnx`` format easily using the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ use_adapters=1
|
||||
$ adapter_dim=16
|
||||
|
||||
$ ./zipformer_adapter/export-onnx.py \
|
||||
--tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model 1 \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||
--use-adapters $use_adapters \
|
||||
--adapter-dim $adapter_dim \
|
||||
--num-encoder-layers "2,2,3,4,3,2" \
|
||||
--downsampling-factor "1,2,4,8,4,2" \
|
||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||
--num-heads "4,4,4,8,4,4" \
|
||||
--encoder-dim "192,256,384,512,384,256" \
|
||||
--query-head-dim 32 \
|
||||
--value-head-dim 12 \
|
||||
--pos-head-dim 4 \
|
||||
--pos-dim 48 \
|
||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal False \
|
||||
--chunk-size "16,32,64,-1" \
|
||||
--left-context-frames "64,128,256,-1"
|
@ -13,3 +13,4 @@ data to improve the performance on new domains.
|
||||
:caption: Table of Contents
|
||||
|
||||
from_supervised/finetune_zipformer
|
||||
adapter/finetune_adapter
|
||||
|
@ -13,6 +13,14 @@ with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
|
||||
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
|
||||
|
||||
|
||||
Install extra dependencies
|
||||
--------------------------
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
|
||||
pip install numba espnet_tts_frontend
|
||||
|
||||
Data preparation
|
||||
----------------
|
||||
|
||||
@ -56,7 +64,8 @@ Training
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir vits/exp \
|
||||
--tokens data/tokens.txt
|
||||
--tokens data/tokens.txt \
|
||||
--model-type high \
|
||||
--max-duration 500
|
||||
|
||||
.. note::
|
||||
@ -64,6 +73,11 @@ Training
|
||||
You can adjust the hyper-parameters to control the size of the VITS model and
|
||||
the training configurations. For more details, please run ``./vits/train.py --help``.
|
||||
|
||||
.. warning::
|
||||
|
||||
If you want a model that runs faster on CPU, please use ``--model-type low``
|
||||
or ``--model-type medium``.
|
||||
|
||||
.. note::
|
||||
|
||||
The training can take a long time (usually a couple of days).
|
||||
@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
|
||||
Export models
|
||||
-------------
|
||||
|
||||
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
|
||||
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
|
||||
Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
|
||||
``vits-epoch-*.onnx``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@ -120,4 +134,68 @@ Download pretrained models
|
||||
If you don't want to train from scratch, you can download the pretrained models
|
||||
by visiting the following link:
|
||||
|
||||
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
|
||||
- ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
|
||||
- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
|
||||
- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_
|
||||
|
||||
Usage in sherpa-onnx
|
||||
--------------------
|
||||
|
||||
The following describes how to test the exported ONNX model in `sherpa-onnx`_.
|
||||
|
||||
.. hint::
|
||||
|
||||
`sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
|
||||
Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.
|
||||
|
||||
We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
|
||||
Please refer to `<https://k2-fsa.github.io/sherpa/onnx/>`_
|
||||
for more documentation.
|
||||
|
||||
Install sherpa-onnx
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install sherpa-onnx
|
||||
|
||||
To check that you have installed `sherpa-onnx`_ successfully, please run:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
which sherpa-onnx-offline-tts
|
||||
sherpa-onnx-offline-tts --help
|
||||
|
||||
Download lexicon files
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd /tmp
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
|
||||
tar xf espeak-ng-data.tar.bz2
|
||||
|
||||
Run sherpa-onnx
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd egs/ljspeech/TTS
|
||||
|
||||
sherpa-onnx-offline-tts \
|
||||
--vits-model=vits/exp/vits-epoch-1000.onnx \
|
||||
--vits-tokens=data/tokens.txt \
|
||||
--vits-data-dir=/tmp/espeak-ng-data \
|
||||
--num-threads=1 \
|
||||
--output-filename=./high.wav \
|
||||
"Ask not what your country can do for you; ask what you can do for your country."
|
||||
|
||||
.. hint::
|
||||
|
||||
You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
|
||||
as it is generating.
|
||||
|
||||
You should get a file ``high.wav`` after running the above command.
|
||||
|
||||
Congratulations! You have successfully trained and exported a text-to-speech
|
||||
model and run it with `sherpa-onnx`_.
|
||||
|
@ -19,7 +19,9 @@ The following table lists the differences among them.
|
||||
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
|
||||
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
|
||||
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
|
||||
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 |
|
||||
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
|
||||
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
|
||||
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
|
@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc
|
||||
| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |
|
||||
|
||||
```bash
|
||||
./prepare.sh
|
||||
./prepare.sh
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1"
|
||||
|
||||
|
@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "Stage 11: Train RNN LM model"
|
||||
log "Stage 12: Train RNN LM model"
|
||||
python ../../../icefall/rnn_lm/train.py \
|
||||
--start-epoch 0 \
|
||||
--world-size 1 \
|
||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -881,9 +882,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error()
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||
|
@ -85,6 +85,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
@ -878,9 +879,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -871,9 +872,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -250,7 +250,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
@ -882,9 +883,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
@ -881,9 +882,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -86,6 +86,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
@ -985,9 +986,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -83,6 +83,7 @@ from icefall.checkpoint import (
|
||||
update_averaged_model,
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -570,9 +571,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -1,6 +1,6 @@
|
||||
## Results
|
||||
|
||||
### Aishell2 char-based training results
|
||||
### Aishell2 char-based training results
|
||||
|
||||
#### Pruned transducer stateless 5
|
||||
|
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
def compute_fbank_aishell2(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
"train",
|
||||
@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -111,7 +124,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -122,5 +140,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_aishell2(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
whisper_mel_bins=80
|
||||
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
|
||||
log "Stage 30: Compute whisper fbank for aishell2"
|
||||
if [ ! -f data/fbank/.aishell2.whisper.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
touch data/fbank/.aishell2.whisper.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for musan"
|
||||
if [ ! -f data/fbank/.msuan.done ]; then
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets).
|
||||
|
||||
The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.
|
||||
The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.
|
||||
|
||||
(From [Open Speech and Language Resources](https://www.openslr.org/111/))
|
||||
|
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
def compute_fbank_aishell4(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/aishell4")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
"train_S",
|
||||
@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=ChunkedLilcomHdf5Writer,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
|
||||
logging.info("About splitting cuts into smaller chunks")
|
||||
@ -121,7 +135,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -132,5 +151,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_aishell4(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
set -eou pipefail
|
||||
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
stop_stage=7
|
||||
perturb_speed=true
|
||||
|
||||
|
||||
@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Process aishell4"
|
||||
log "Stage 2: Compute fbank for aishell4"
|
||||
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
|
||||
mkdir -p data/fbank/aishell4
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
|
||||
touch data/fbank/aishell4/.fbank.done
|
||||
touch data/fbank/.fbank.done
|
||||
fi
|
||||
fi
|
||||
|
||||
whisper_mel_bins=80
|
||||
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||
log "Stage 20: Compute whisper fbank for aishell4"
|
||||
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
touch data/fbank/.fbank.done
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compute fbank for aishell4"
|
||||
if [ ! -f data/fbank/.aishell4.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
|
||||
touch data/fbank/.aishell4.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare char based lang"
|
||||
log "Stage 5: Prepare char based lang"
|
||||
lang_char_dir=data/lang_char
|
||||
mkdir -p $lang_char_dir
|
||||
|
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
def compute_fbank_alimeeting(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/alimeeting")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
"train",
|
||||
@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
"test",
|
||||
)
|
||||
|
||||
prefix = "alimeeting"
|
||||
prefix = "alimeeting-far"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -121,7 +135,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use the Whisper Fbank feature extractor. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -132,5 +151,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_alimeeting(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
set -eou pipefail
|
||||
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
stop_stage=7
|
||||
perturb_speed=true
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
@ -15,7 +15,7 @@ perturb_speed=true
|
||||
#
|
||||
# - $dl_dir/alimeeting
|
||||
# This directory contains the following files downloaded from
|
||||
# https://openslr.org/62/
|
||||
# https://openslr.org/119/
|
||||
#
|
||||
# - Train_Ali_far.tar.gz
|
||||
# - Train_Ali_near.tar.gz
|
||||
@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Process alimeeting"
|
||||
if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
|
||||
mkdir -p data/fbank/alimeeting
|
||||
log "Stage 2: compute fbank for alimeeting"
|
||||
if [ ! -f data/fbank/.fbank.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
|
||||
touch data/fbank/.fbank.done
|
||||
fi
|
||||
fi
|
||||
|
||||
whisper_mel_bins=80
|
||||
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||
log "Stage 20: compute whisper fbank for alimeeting"
|
||||
if [ ! -f data/fbank/.fbank.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
touch data/fbank/.fbank.done
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compute fbank for alimeeting"
|
||||
if [ ! -f data/fbank/.alimeeting.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_alimeeting.py --perturb-speed True
|
||||
touch data/fbank/.alimeeting.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare char based lang"
|
||||
log "Stage 5: Prepare char based lang"
|
||||
lang_char_dir=data/lang_char
|
||||
mkdir -p $lang_char_dir
|
||||
|
||||
|
@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting
|
||||
#
|
||||
# - $dl_dir/alimeeting
|
||||
# This directory contains the following files downloaded from
|
||||
# https://openslr.org/62/
|
||||
# https://openslr.org/119/
|
||||
#
|
||||
# - Train_Ali_far.tar.gz
|
||||
# - Train_Ali_near.tar.gz
|
||||
|
@ -70,6 +70,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
@ -851,9 +852,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -69,6 +69,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -842,9 +843,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -75,6 +75,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -1138,9 +1139,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -75,6 +75,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -1129,9 +1130,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -1 +0,0 @@
|
||||
../../../librispeech/ASR/local/compile_hlg.py
|
168
egs/commonvoice/ASR/local/compile_hlg.py
Executable file
168
egs/commonvoice/ASR/local/compile_hlg.py
Executable file
@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengrui Jin,)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input lang_dir and generates HLG from
|
||||
|
||||
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
|
||||
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||
|
||||
Caution: We use a lexicon that contains disambiguation symbols
|
||||
|
||||
- G, the LM, built from data/lm/G_n_gram.fst.txt
|
||||
|
||||
The generated HLG is saved in $lang_dir/HLG.pt
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
type=str,
|
||||
default="G_3_gram",
|
||||
help="""Stem name for LM used in HLG compiling.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
|
||||
lm:
|
||||
The language stem base name.
|
||||
|
||||
Return:
|
||||
An FSA representing HLG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
|
||||
|
||||
first_token_disambig_id = lexicon.token_table["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
L = k2.arc_sort(L)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
logging.info("Intersecting L and G")
|
||||
LG = k2.compose(L, G)
|
||||
logging.info(f"LG shape: {LG.shape}")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
LG = k2.connect(LG)
|
||||
logging.info(f"LG shape after k2.connect: {LG.shape}")
|
||||
|
||||
logging.info(type(LG.aux_labels))
|
||||
logging.info("Determinizing LG")
|
||||
|
||||
LG = k2.determinize(LG)
|
||||
logging.info(type(LG.aux_labels))
|
||||
|
||||
logging.info("Connecting LG after k2.determinize")
|
||||
LG = k2.connect(LG)
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
# LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# see https://github.com/k2-fsa/k2/pull/1140
|
||||
labels = LG.labels
|
||||
labels[labels >= first_token_disambig_id] = 0
|
||||
LG.labels = labels
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
||||
LG = k2.remove_epsilon(LG)
|
||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||
|
||||
LG = k2.connect(LG)
|
||||
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
LG = k2.arc_sort(LG)
|
||||
|
||||
logging.info("Composing H and LG")
|
||||
# CAUTION: The name of the inner_labels is fixed
|
||||
# to `tokens`. If you want to change it, please
|
||||
# also change other places in icefall that are using
|
||||
# it.
|
||||
HLG = k2.compose(H, LG, inner_labels="tokens")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
HLG = k2.connect(HLG)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
HLG = k2.arc_sort(HLG)
|
||||
logging.info(f"HLG.shape: {HLG.shape}")
|
||||
|
||||
return HLG
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
if (lang_dir / "HLG.pt").is_file():
|
||||
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
HLG = compile_HLG(lang_dir, args.lm)
|
||||
logging.info(f"Saving HLG.pt to {lang_dir}")
|
||||
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
@ -1 +0,0 @@
|
||||
../../../librispeech/ASR/local/compile_lg.py
|
149
egs/commonvoice/ASR/local/compile_lg.py
Executable file
149
egs/commonvoice/ASR/local/compile_lg.py
Executable file
@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Kang Wei,
|
||||
# Zengrui Jin,)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input lang_dir and generates LG from
|
||||
|
||||
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||
|
||||
Caution: We use a lexicon that contains disambiguation symbols
|
||||
|
||||
- G, the LM, built from lang_dir/lm/G_3_gram.fst.txt
|
||||
|
||||
The generated LG is saved in $lang_dir/LG.pt
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
type=str,
|
||||
default="G_3_gram",
|
||||
help="""Stem name for LM used in HLG compiling.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
|
||||
|
||||
Return:
|
||||
An FSA representing LG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
|
||||
|
||||
first_token_disambig_id = lexicon.token_table["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
L = k2.arc_sort(L)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
logging.info("Intersecting L and G")
|
||||
LG = k2.compose(L, G)
|
||||
logging.info(f"LG shape: {LG.shape}")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
LG = k2.connect(LG)
|
||||
logging.info(f"LG shape after k2.connect: {LG.shape}")
|
||||
|
||||
logging.info(type(LG.aux_labels))
|
||||
logging.info("Determinizing LG")
|
||||
|
||||
LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
|
||||
logging.info(type(LG.aux_labels))
|
||||
|
||||
logging.info("Connecting LG after k2.determinize")
|
||||
LG = k2.connect(LG)
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
# LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# see https://github.com/k2-fsa/k2/pull/1140
|
||||
labels = LG.labels
|
||||
labels[labels >= first_token_disambig_id] = 0
|
||||
LG.labels = labels
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
||||
LG = k2.remove_epsilon(LG)
|
||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||
|
||||
LG = k2.connect(LG)
|
||||
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
LG = k2.arc_sort(LG)
|
||||
|
||||
return LG
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
if (lang_dir / "LG.pt").is_file():
|
||||
logging.info(f"{lang_dir}/LG.pt already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
LG = compile_LG(lang_dir, args.lm)
|
||||
logging.info(f"Saving LG.pt to {lang_dir}")
|
||||
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
@ -48,8 +48,27 @@ def normalize_text(utt: str, language: str) -> str:
|
||||
utt = re.sub("’", "'", utt)
|
||||
if language == "en":
|
||||
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
|
||||
if language == "fr":
|
||||
elif language == "fr":
|
||||
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
|
||||
elif language == "pl":
|
||||
return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
|
||||
elif language == "yue":
|
||||
return (
|
||||
utt.replace(" ", "")
|
||||
.replace(",", "")
|
||||
.replace("。", " ")
|
||||
.replace("?", "")
|
||||
.replace("!", "")
|
||||
.replace("?", "")
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"""
|
||||
Text normalization not implemented for language: {language},
|
||||
please consider implementing it in the local/preprocess_commonvoice.py
|
||||
or raise an issue on GitHub to request it.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def preprocess_commonvoice(
|
||||
|
@ -381,9 +381,11 @@ class CommonVoiceAsrDataModule:
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
input_strategy=(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)()
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
|
@ -79,6 +79,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -871,9 +872,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
@ -232,7 +232,7 @@ class CommonVoiceAsrDataModule:
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
@ -315,8 +315,8 @@ class CommonVoiceAsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
@ -383,9 +383,11 @@ class CommonVoiceAsrDataModule:
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
input_strategy=(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)()
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
|
@ -889,9 +889,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise RuntimeError(f", exiting: {cur_grad_scale}")
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -965,9 +966,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -888,9 +889,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -909,9 +910,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -908,9 +909,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1031,9 +1032,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -100,6 +100,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -371,9 +372,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1034,9 +1035,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -85,6 +85,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -1169,9 +1170,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -1056,9 +1057,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -232,7 +232,7 @@ class LibriHeavyAsrDataModule:
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
@ -93,6 +93,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1036,9 +1037,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -425,9 +425,11 @@ class LibriHeavyAsrDataModule:
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
input_strategy=(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures()
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
|
@ -103,6 +103,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1051,9 +1052,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -117,6 +117,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -855,9 +856,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
context_dim=4 * 768
|
||||
if params.context_injection
|
||||
else -1, # the output dim of text encoder
|
||||
context_dim=(
|
||||
4 * 768 if params.context_injection else -1
|
||||
), # the output dim of text encoder
|
||||
context_injection=params.context_injection,
|
||||
)
|
||||
return joiner
|
||||
@ -1398,9 +1399,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -36,6 +36,7 @@ The following table lists the differences among them.
|
||||
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
|
||||
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
|
||||
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
|
||||
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | Finetune Zipformer with LoRA |
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
|
@ -80,6 +80,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -976,9 +977,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -878,9 +879,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -902,9 +903,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -77,6 +77,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -891,9 +892,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -80,6 +80,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -880,9 +881,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -80,6 +80,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -879,9 +880,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -84,6 +84,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -946,9 +947,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -946,9 +947,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -479,18 +479,14 @@ class LibriSpeechAsrDataModule:
|
||||
@lru_cache()
|
||||
def gigaspeech_subset_small_cuts(self) -> CutSet:
|
||||
logging.info("About to get Gigaspeech subset-S cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def gigaspeech_dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get Gigaspeech dev cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
|
||||
)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def gigaspeech_test_cuts(self) -> CutSet:
|
||||
logging.info("About to get Gigaspeech test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
|
||||
)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
|
||||
|
@ -66,6 +66,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import UniqLexicon
|
||||
from icefall.utils import (
|
||||
@ -883,9 +884,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -92,6 +92,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1122,9 +1123,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -90,6 +90,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1021,9 +1022,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1125,9 +1126,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -62,6 +62,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
@ -797,9 +798,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
1
egs/librispeech/ASR/zipformer_lora/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../tdnn_lstm_ctc/asr_datamodule.py
|
1
egs/librispeech/ASR/zipformer_lora/beam_search.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/beam_search.py
|
1115
egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
Executable file
1115
egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/ASR/zipformer_lora/decoder.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/decoder.py
|
1
egs/librispeech/ASR/zipformer_lora/encoder_interface.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../transducer_stateless/encoder_interface.py
|
543
egs/librispeech/ASR/zipformer_lora/export.py
Executable file
543
egs/librispeech/ASR/zipformer_lora/export.py
Executable file
@ -0,0 +1,543 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
Note: This is a example for librispeech dataset, if you are using different
|
||||
dataset, you should change the argument values according to your dataset.
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer_lora/export.py \
|
||||
--exp-dir ./zipformer_lora/exp \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("jit_script.pt")`.
|
||||
|
||||
Check ./jit_pretrained.py for its usage.
|
||||
|
||||
Check https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer_lora/export.py \
|
||||
--exp-dir ./zipformer_lora/exp \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
|
||||
|
||||
Check ./jit_pretrained_streaming.py for its usage.
|
||||
|
||||
Check https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer_lora/export.py \
|
||||
--exp-dir ./zipformer_lora/exp \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer_lora/export.py \
|
||||
--exp-dir ./zipformer_lora/exp \
|
||||
--causal 1 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
To use the generated file with `zipformer_lora/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./zipformer_lora/decode.py \
|
||||
--exp-dir ./zipformer_lora/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
- For streaming model:
|
||||
|
||||
To use the generated file with `zipformer_lora/decode.py` and `zipformer_lora/streaming_decode.py`, you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
|
||||
# simulated streaming decoding
|
||||
./zipformer_lora/decode.py \
|
||||
--exp-dir ./zipformer_lora/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
# chunk-wise streaming decoding
|
||||
./zipformer_lora/streaming_decode.py \
|
||||
--exp-dir ./zipformer_lora/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
- non-streaming model:
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
|
||||
- streaming model:
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
# You will find the pre-trained models in exp dir
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from torch import Tensor, nn
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer_lora/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/tokens.txt",
|
||||
help="Path to the tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named jit_script.pt.
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
add_finetune_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class EncoderModel(nn.Module):
|
||||
"""A wrapper for encoder and encoder_embed"""
|
||||
|
||||
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
|
||||
def forward(
|
||||
self, features: Tensor, feature_lengths: Tensor
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
features: (N, T, C)
|
||||
feature_lengths: (N,)
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(features, feature_lengths)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
class StreamingEncoderModel(nn.Module):
|
||||
"""A wrapper for encoder and encoder_embed"""
|
||||
|
||||
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
|
||||
super().__init__()
|
||||
assert len(encoder.chunk_size) == 1, encoder.chunk_size
|
||||
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
|
||||
self.chunk_size = encoder.chunk_size[0]
|
||||
self.left_context_len = encoder.left_context_frames[0]
|
||||
|
||||
# The encoder_embed subsample features (T - 7) // 2
|
||||
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||
self.pad_length = 7 + 2 * 3
|
||||
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
|
||||
def forward(
|
||||
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
|
||||
) -> Tuple[Tensor, Tensor, List[Tensor]]:
|
||||
"""Streaming forward for encoder_embed and encoder.
|
||||
|
||||
Args:
|
||||
features: (N, T, C)
|
||||
feature_lengths: (N,)
|
||||
states: a list of Tensors
|
||||
|
||||
Returns encoder outputs, output lengths, and updated states.
|
||||
"""
|
||||
chunk_size = self.chunk_size
|
||||
left_context_len = self.left_context_len
|
||||
|
||||
cached_embed_left_pad = states[-2]
|
||||
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
cached_left_pad=cached_embed_left_pad,
|
||||
)
|
||||
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
|
||||
# processed_mask is used to mask out initial states
|
||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||
x.size(0), left_context_len
|
||||
)
|
||||
processed_lens = states[-1] # (batch,)
|
||||
# (batch, left_context_size)
|
||||
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
|
||||
# Update processed lengths
|
||||
new_processed_lens = processed_lens + x_lens
|
||||
|
||||
# (batch, left_context_size + chunk_size)
|
||||
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
|
||||
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
encoder_states = states[:-2]
|
||||
|
||||
(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
new_encoder_states,
|
||||
) = self.encoder.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=encoder_states,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
new_states = new_encoder_states + [
|
||||
new_cached_embed_left_pad,
|
||||
new_processed_lens,
|
||||
]
|
||||
return encoder_out, encoder_out_lens, new_states
|
||||
|
||||
@torch.jit.export
|
||||
def get_init_states(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
|
||||
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
|
||||
states[-2] is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
states[-1] is processed_lens of shape (batch,), which records the number
|
||||
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||
"""
|
||||
states = self.encoder.get_init_states(batch_size, device)
|
||||
|
||||
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
||||
states.append(embed_states)
|
||||
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
states.append(processed_lens)
|
||||
|
||||
return states
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
# if torch.cuda.is_available():
|
||||
# device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
# merge the LoRA weights
|
||||
model.eval()
|
||||
|
||||
params.use_lora = False
|
||||
base_model = get_model(params)
|
||||
|
||||
new_state_dict = {}
|
||||
state_dict = model.state_dict()
|
||||
param_names = base_model.state_dict().keys()
|
||||
for k in param_names:
|
||||
assert k in state_dict.keys()
|
||||
new_state_dict[k] = state_dict[k]
|
||||
|
||||
base_model.load_state_dict(new_state_dict, strict=True)
|
||||
|
||||
model = base_model
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
|
||||
# Wrap encoder and encoder_embed as a module
|
||||
if params.causal:
|
||||
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
|
||||
chunk_size = model.encoder.chunk_size
|
||||
left_context_len = model.encoder.left_context_len
|
||||
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
|
||||
else:
|
||||
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
|
||||
filename = "jit_script.pt"
|
||||
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
model.save(str(params.exp_dir / filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1553
egs/librispeech/ASR/zipformer_lora/finetune.py
Executable file
1553
egs/librispeech/ASR/zipformer_lora/finetune.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/ASR/zipformer_lora/joiner.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/joiner.py
|
1
egs/librispeech/ASR/zipformer_lora/model.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/model.py
|
1
egs/librispeech/ASR/zipformer_lora/optim.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/optim.py
|
2052
egs/librispeech/ASR/zipformer_lora/scaling.py
Normal file
2052
egs/librispeech/ASR/zipformer_lora/scaling.py
Normal file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/ASR/zipformer_lora/scaling_converter.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/scaling_converter.py
|
1
egs/librispeech/ASR/zipformer_lora/subsampling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_lora/subsampling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/subsampling.py
|
1398
egs/librispeech/ASR/zipformer_lora/train.py
Executable file
1398
egs/librispeech/ASR/zipformer_lora/train.py
Executable file
File diff suppressed because it is too large
Load Diff
2522
egs/librispeech/ASR/zipformer_lora/zipformer.py
Normal file
2522
egs/librispeech/ASR/zipformer_lora/zipformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -79,6 +79,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon, UniqLexicon
|
||||
from icefall.mmi import LFMMILoss
|
||||
@ -816,9 +817,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -31,6 +31,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--exp-dir conformer_ctc2/exp \
|
||||
--lang-dir data/lang_bpe_200 \
|
||||
--otc-token "<star>" \
|
||||
--feature-dim 768 \
|
||||
--allow-bypass-arc true \
|
||||
--allow-self-loop-arc true \
|
||||
--initial-bypass-weight -19 \
|
||||
@ -160,6 +161,14 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature-dim",
|
||||
type=int,
|
||||
default=768,
|
||||
help="""Number of features extracted in feature extraction stage.last dimension of feature vector.
|
||||
80 when using fbank features and 768 or 1024 whn using wave2vec""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
@ -385,7 +394,6 @@ def get_params() -> AttributeDict:
|
||||
"valid_interval": 800, # For the 100h subset, use 800
|
||||
"alignment_interval": 25,
|
||||
# parameters for conformer
|
||||
"feature_dim": 768,
|
||||
"subsampling_factor": 2,
|
||||
"encoder_dim": 512,
|
||||
"nhead": 8,
|
||||
|
@ -1,10 +1,10 @@
|
||||
# Introduction
|
||||
|
||||
This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
|
||||
A transcription is provided for each clip.
|
||||
This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
|
||||
A transcription is provided for each clip.
|
||||
Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.
|
||||
|
||||
The texts were published between 1884 and 1964, and are in the public domain.
|
||||
The texts were published between 1884 and 1964, and are in the public domain.
|
||||
The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain.
|
||||
|
||||
The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/).
|
||||
@ -35,4 +35,69 @@ To inference, use:
|
||||
--exp-dir vits/exp \
|
||||
--epoch 1000 \
|
||||
--tokens data/tokens.txt
|
||||
```
|
||||
```
|
||||
|
||||
## Quality vs speed
|
||||
|
||||
If you feel that the trained model is slow at runtime, you can specify the
|
||||
argument `--model-type` during training. Possible values are:
|
||||
|
||||
- `low`, means **low** quality. The resulting model is very small in file size
|
||||
and runs very fast. The following is a wave file generatd by a `low` quality model
|
||||
|
||||
https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633
|
||||
|
||||
The text is `Ask not what your country can do for you; ask what you can do for your country.`
|
||||
|
||||
The exported onnx model has a file size of ``26.8 MB`` (float32).
|
||||
|
||||
- `medium`, means **medium** quality.
|
||||
The following is a wave file generatd by a `medium` quality model
|
||||
|
||||
https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac
|
||||
|
||||
The text is `Ask not what your country can do for you; ask what you can do for your country.`
|
||||
|
||||
The exported onnx model has a file size of ``70.9 MB`` (float32).
|
||||
|
||||
- `high`, means **high** quality. This is the default value.
|
||||
|
||||
The following is a wave file generatd by a `high` quality model
|
||||
|
||||
https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc
|
||||
|
||||
The text is `Ask not what your country can do for you; ask what you can do for your country.`
|
||||
|
||||
The exported onnx model has a file size of ``113 MB`` (float32).
|
||||
|
||||
|
||||
A pre-trained `low` model trained using 4xV100 32GB GPU with the following command can be found at
|
||||
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
./vits/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 1601 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir vits/exp \
|
||||
--model-type low \
|
||||
--max-duration 800
|
||||
```
|
||||
|
||||
A pre-trained `medium` model trained using 4xV100 32GB GPU with the following command can be found at
|
||||
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
./vits/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 1000 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir vits/exp-medium \
|
||||
--model-type medium \
|
||||
--max-duration 500
|
||||
|
||||
# (Note it is killed after `epoch-820.pt`)
|
||||
```
|
||||
|
@ -23,7 +23,11 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import tacotron_cleaner.cleaners
|
||||
try:
|
||||
import tacotron_cleaner.cleaners
|
||||
except ModuleNotFoundError as ex:
|
||||
raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n")
|
||||
|
||||
from lhotse import CutSet, load_manifest
|
||||
from piper_phonemize import phonemize_espeak
|
||||
|
||||
|
@ -28,7 +28,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "Stage -1: build monotonic_align lib"
|
||||
if [ ! -d vits/monotonic_align/build ]; then
|
||||
cd vits/monotonic_align
|
||||
python setup.py build_ext --inplace
|
||||
python3 setup.py build_ext --inplace
|
||||
cd ../../
|
||||
else
|
||||
log "monotonic_align lib already built"
|
||||
@ -54,7 +54,7 @@ fi
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare LJSpeech manifest"
|
||||
# We assume that you have downloaded the LJSpeech corpus
|
||||
# to $dl_dir/LJSpeech
|
||||
# to $dl_dir/LJSpeech-1.1
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.ljspeech.done ]; then
|
||||
lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
|
||||
@ -82,8 +82,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Prepare phoneme tokens for LJSpeech"
|
||||
# We assume you have installed piper_phonemize and espnet_tts_frontend.
|
||||
# If not, please install them with:
|
||||
# - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
|
||||
# could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
|
||||
# - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html,
|
||||
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
|
||||
if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
|
||||
./local/prepare_tokens_ljspeech.py
|
||||
|
@ -25,9 +25,8 @@ Export the model to ONNX:
|
||||
--exp-dir vits/exp \
|
||||
--tokens data/tokens.txt
|
||||
|
||||
It will generate two files inside vits/exp:
|
||||
It will generate one file inside vits/exp:
|
||||
- vits-epoch-1000.onnx
|
||||
- vits-epoch-1000.int8.onnx (quantizated model)
|
||||
|
||||
See ./test_onnx.py for how to use the exported ONNX models.
|
||||
"""
|
||||
@ -40,7 +39,6 @@ from typing import Dict, Tuple
|
||||
import onnx
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from tokenizer import Tokenizer
|
||||
from train import get_model, get_params
|
||||
|
||||
@ -75,6 +73,16 @@ def get_parser():
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
default="high",
|
||||
choices=["low", "medium", "high"],
|
||||
help="""If not empty, valid values are: low, medium, high.
|
||||
It controls the model size. low -> runs faster.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -136,7 +144,7 @@ class OnnxModel(nn.Module):
|
||||
Return a tuple containing:
|
||||
- audio, generated wavform tensor, (B, T_wav)
|
||||
"""
|
||||
audio, _, _ = self.model.inference(
|
||||
audio, _, _ = self.model.generator.inference(
|
||||
text=tokens,
|
||||
text_lengths=tokens_lens,
|
||||
noise_scale=noise_scale,
|
||||
@ -198,6 +206,11 @@ def export_model_onnx(
|
||||
},
|
||||
)
|
||||
|
||||
if model.model.spks is None:
|
||||
num_speakers = 1
|
||||
else:
|
||||
num_speakers = model.model.spks
|
||||
|
||||
meta_data = {
|
||||
"model_type": "vits",
|
||||
"version": "1",
|
||||
@ -206,8 +219,8 @@ def export_model_onnx(
|
||||
"language": "English",
|
||||
"voice": "en-us", # Choose your language appropriately
|
||||
"has_espeak": 1,
|
||||
"n_speakers": 1,
|
||||
"sample_rate": 22050, # Must match the real sample rate
|
||||
"n_speakers": num_speakers,
|
||||
"sample_rate": model.model.sampling_rate, # Must match the real sample rate
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
@ -233,14 +246,13 @@ def main():
|
||||
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
|
||||
model = model.generator
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
model = OnnxModel(model=model)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"generator parameters: {num_param}")
|
||||
logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")
|
||||
|
||||
suffix = f"epoch-{params.epoch}"
|
||||
|
||||
@ -256,18 +268,6 @@ def main():
|
||||
)
|
||||
logging.info(f"Exported generator to {model_filename}")
|
||||
|
||||
# Generate int8 quantization models
|
||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||
|
||||
logging.info("Generate int8 quantization models")
|
||||
|
||||
model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
|
||||
quantize_dynamic(
|
||||
model_input=model_filename,
|
||||
model_output=model_filename_int8,
|
||||
weight_type=QuantType.QUInt8,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module):
|
||||
self.upsample_factor = int(np.prod(decoder_upsample_scales))
|
||||
self.spks = None
|
||||
if spks is not None and spks > 1:
|
||||
assert global_channels > 0
|
||||
assert global_channels > 0, global_channels
|
||||
self.spks = spks
|
||||
self.global_emb = torch.nn.Embedding(spks, global_channels)
|
||||
self.spk_embed_dim = None
|
||||
|
@ -72,6 +72,16 @@ def get_parser():
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
default="high",
|
||||
choices=["low", "medium", "high"],
|
||||
help="""If not empty, valid values are: low, medium, high.
|
||||
It controls the model size. low -> runs faster.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -94,6 +104,7 @@ def infer_dataset(
|
||||
tokenizer:
|
||||
Used to convert text to phonemes.
|
||||
"""
|
||||
|
||||
# Background worker save audios to disk.
|
||||
def _save_worker(
|
||||
batch_size: int,
|
||||
|
@ -10,7 +10,11 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numba import njit, prange
|
||||
|
||||
try:
|
||||
from numba import njit, prange
|
||||
except ModuleNotFoundError as ex:
|
||||
raise RuntimeError(f"{ex}/nPlease run\n pip install numba")
|
||||
|
||||
try:
|
||||
from .core import maximum_path_c
|
||||
|
50
egs/ljspeech/TTS/vits/test_model.py
Executable file
50
egs/ljspeech/TTS/vits/test_model.py
Executable file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from tokenizer import Tokenizer
|
||||
from train import get_model, get_params
|
||||
from vits import VITS
|
||||
|
||||
|
||||
def test_model_type(model_type):
|
||||
tokens = "./data/tokens.txt"
|
||||
|
||||
params = get_params()
|
||||
|
||||
tokenizer = Tokenizer(tokens)
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.model_type = model_type
|
||||
|
||||
model = get_model(params)
|
||||
generator = model.generator
|
||||
|
||||
num_param = sum([p.numel() for p in generator.parameters()])
|
||||
print(
|
||||
f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
test_model_type("high") # 35.63 M
|
||||
test_model_type("low") # 7.55 M
|
||||
test_model_type("medium") # 23.61 M
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -54,6 +54,20 @@ def get_parser():
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
default="Ask not what your country can do for you; ask what you can do for your country.",
|
||||
help="Text to generate speech for",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-filename",
|
||||
type=str,
|
||||
default="test_onnx.wav",
|
||||
help="Filename to save the generated wave file.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -61,7 +75,7 @@ class OnnxModel:
|
||||
def __init__(self, model_filename: str):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 4
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
@ -72,6 +86,9 @@ class OnnxModel:
|
||||
)
|
||||
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
|
||||
|
||||
metadata = self.model.get_modelmeta().custom_metadata_map
|
||||
self.sample_rate = int(metadata["sample_rate"])
|
||||
|
||||
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -101,13 +118,14 @@ class OnnxModel:
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
tokenizer = Tokenizer(args.tokens)
|
||||
|
||||
logging.info("About to create onnx model")
|
||||
model = OnnxModel(args.model_filename)
|
||||
|
||||
text = "I went there to see the land, the people and how their system works, end quote."
|
||||
text = args.text
|
||||
tokens = tokenizer.texts_to_token_ids(
|
||||
[text], intersperse_blank=True, add_sos=True, add_eos=True
|
||||
)
|
||||
@ -115,8 +133,9 @@ def main():
|
||||
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
|
||||
audio = model(tokens, tokens_lens) # (1, T')
|
||||
|
||||
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
|
||||
logging.info("Saved to test_onnx.wav")
|
||||
output_filename = args.output_filename
|
||||
torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
|
||||
logging.info(f"Saved to {output_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module):
|
||||
x_lengths (Tensor): Length tensor (B,).
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded hidden representation (B, attention_dim, T_text).
|
||||
Tensor: Projected mean tensor (B, attention_dim, T_text).
|
||||
Tensor: Projected scale tensor (B, attention_dim, T_text).
|
||||
Tensor: Encoded hidden representation (B, embed_dim, T_text).
|
||||
Tensor: Projected mean tensor (B, embed_dim, T_text).
|
||||
Tensor: Projected scale tensor (B, embed_dim, T_text).
|
||||
Tensor: Mask tensor for input tensor (B, 1, T_text).
|
||||
|
||||
"""
|
||||
@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module):
|
||||
|
||||
# encoder assume the channel last (B, T_text, embed_dim)
|
||||
x = self.encoder(x, key_padding_mask=pad_mask)
|
||||
# Note: attention_dim == embed_dim
|
||||
|
||||
# convert the channel first (B, embed_dim, T_text)
|
||||
x = x.transpose(1, 2)
|
||||
|
@ -18,7 +18,15 @@ import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import tacotron_cleaner.cleaners
|
||||
from piper_phonemize import phonemize_espeak
|
||||
|
||||
try:
|
||||
from piper_phonemize import phonemize_espeak
|
||||
except Exception as ex:
|
||||
raise RuntimeError(
|
||||
f"{ex}\nPlease run\n"
|
||||
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
|
||||
)
|
||||
|
||||
from utils import intersperse
|
||||
|
||||
|
||||
|
@ -153,6 +153,16 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
type=str,
|
||||
default="high",
|
||||
choices=["low", "medium", "high"],
|
||||
help="""If not empty, valid values are: low, medium, high.
|
||||
It controls the model size. low -> runs faster.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -189,15 +199,6 @@ def get_params() -> AttributeDict:
|
||||
|
||||
- feature_dim: The model input dim. It has to match the one used
|
||||
in computing features.
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- encoder_dim: Hidden dim for multi-head attention model.
|
||||
|
||||
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
||||
|
||||
- warm_step: The warmup period that dictates the decay of the
|
||||
scale on "simple" (un-pruned) loss.
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
@ -278,6 +279,7 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
vocab_size=params.vocab_size,
|
||||
feature_dim=params.feature_dim,
|
||||
sampling_rate=params.sampling_rate,
|
||||
model_type=params.model_type,
|
||||
mel_loss_params=mel_loss_params,
|
||||
lambda_adv=params.lambda_adv,
|
||||
lambda_mel=params.lambda_mel,
|
||||
@ -363,7 +365,7 @@ def train_one_epoch(
|
||||
model.train()
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
|
||||
# used to summary the stats over iterations in one epoch
|
||||
# used to track the stats over iterations in one epoch
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
saved_bad_model = False
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user