mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
resolve conflict
This commit is contained in:
commit
390f01653f
1
.github/scripts/.gitignore
vendored
Normal file
1
.github/scripts/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
piper_phonemize.html
|
29
.github/scripts/generate-piper-phonemize-page.py
vendored
Executable file
29
.github/scripts/generate-piper-phonemize-page.py
vendored
Executable file
@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
|
||||||
|
def main(output="piper_phonemize.html"):
    """Write an HTML index page linking to the piper_phonemize wheels.

    Every wheel listed below is emitted as one ``<a href="...">`` line that
    points at the corresponding asset of the ``2023.12.5`` release of
    csukuangfj/piper-phonemize, so the page can be used as a link index
    (e.g. with ``pip install -f <url-of-page>``).

    Args:
      output:
        Path of the HTML file to create or overwrite. Defaults to
        ``piper_phonemize.html`` in the current working directory.
    """
    prefix = (
        "https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
    )
    files = [
        "piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
    ]
    # Pin the encoding so the generated page does not depend on the locale
    # of the machine that runs this script.
    with open(output, "w", encoding="utf-8") as f:
        f.writelines(
            f'<a href="{prefix}{file}">{file}</a><br/>\n' for file in files
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
6
.github/scripts/librispeech/ASR/run.sh
vendored
6
.github/scripts/librispeech/ASR/run.sh
vendored
@ -15,9 +15,9 @@ function prepare_data() {
|
|||||||
# cause OOM error for CI later.
|
# cause OOM error for CI later.
|
||||||
mkdir -p download/lm
|
mkdir -p download/lm
|
||||||
pushd download/lm
|
pushd download/lm
|
||||||
wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
|
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
|
||||||
wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
|
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
|
||||||
wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
|
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt
|
||||||
ls -lh
|
ls -lh
|
||||||
gunzip librispeech-lm-norm.txt.gz
|
gunzip librispeech-lm-norm.txt.gz
|
||||||
|
|
||||||
|
157
.github/scripts/ljspeech/TTS/run.sh
vendored
Executable file
157
.github/scripts/ljspeech/TTS/run.sh
vendored
Executable file
@ -0,0 +1,157 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
|
||||||
|
python3 -m pip install espnet_tts_frontend
|
||||||
|
python3 -m pip install numba
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
cd egs/ljspeech/TTS
|
||||||
|
|
||||||
|
sed -i.bak s/600/8/g ./prepare.sh
|
||||||
|
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
|
||||||
|
sed -i.bak s/500/5/g ./prepare.sh
|
||||||
|
git diff
|
||||||
|
|
||||||
|
function prepare_data() {
|
||||||
|
# We have created a subset of the data for testing
|
||||||
|
#
|
||||||
|
mkdir download
|
||||||
|
pushd download
|
||||||
|
wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
|
||||||
|
tar xvf LJSpeech-1.1.tar.bz2
|
||||||
|
popd
|
||||||
|
|
||||||
|
./prepare.sh
|
||||||
|
tree .
|
||||||
|
}
|
||||||
|
|
||||||
|
function train() {
|
||||||
|
pushd ./vits
|
||||||
|
sed -i.bak s/200/3/g ./train.py
|
||||||
|
git diff .
|
||||||
|
popd
|
||||||
|
|
||||||
|
for t in low medium high; do
|
||||||
|
./vits/train.py \
|
||||||
|
--exp-dir vits/exp-$t \
|
||||||
|
--model-type $t \
|
||||||
|
--num-epochs 1 \
|
||||||
|
--save-every-n 1 \
|
||||||
|
--num-buckets 2 \
|
||||||
|
--tokens data/tokens.txt \
|
||||||
|
--max-duration 20
|
||||||
|
|
||||||
|
ls -lh vits/exp-$t
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
function infer() {
|
||||||
|
for t in low medium high; do
|
||||||
|
./vits/infer.py \
|
||||||
|
--num-buckets 2 \
|
||||||
|
--model-type $t \
|
||||||
|
--epoch 1 \
|
||||||
|
--exp-dir ./vits/exp-$t \
|
||||||
|
--tokens data/tokens.txt \
|
||||||
|
--max-duration 20
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
function export_onnx() {
|
||||||
|
for t in low medium high; do
|
||||||
|
./vits/export-onnx.py \
|
||||||
|
--model-type $t \
|
||||||
|
--epoch 1 \
|
||||||
|
--exp-dir ./vits/exp-$t \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
|
||||||
|
ls -lh vits/exp-$t/
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
function test_medium() {
|
||||||
|
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12
|
||||||
|
|
||||||
|
./vits/export-onnx.py \
|
||||||
|
--model-type medium \
|
||||||
|
--epoch 820 \
|
||||||
|
--exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
|
||||||
|
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt
|
||||||
|
|
||||||
|
ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp
|
||||||
|
|
||||||
|
./vits/test_onnx.py \
|
||||||
|
--model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
|
||||||
|
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
|
||||||
|
--output-filename /icefall/test-medium.wav
|
||||||
|
|
||||||
|
ls -lh /icefall/test-medium.wav
|
||||||
|
|
||||||
|
d=/icefall/vits-icefall-en_US-ljspeech-medium
|
||||||
|
mkdir $d
|
||||||
|
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
|
||||||
|
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx
|
||||||
|
|
||||||
|
rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12
|
||||||
|
|
||||||
|
pushd $d
|
||||||
|
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
|
||||||
|
tar xf espeak-ng-data.tar.bz2
|
||||||
|
rm espeak-ng-data.tar.bz2
|
||||||
|
cd ..
|
||||||
|
tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
|
||||||
|
rm -rf vits-icefall-en_US-ljspeech-medium
|
||||||
|
ls -lh *.tar.bz2
|
||||||
|
popd
|
||||||
|
}
|
||||||
|
|
||||||
|
function test_low() {
|
||||||
|
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12
|
||||||
|
|
||||||
|
./vits/export-onnx.py \
|
||||||
|
--model-type low \
|
||||||
|
--epoch 1600 \
|
||||||
|
--exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
|
||||||
|
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt
|
||||||
|
|
||||||
|
ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp
|
||||||
|
|
||||||
|
./vits/test_onnx.py \
|
||||||
|
--model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
|
||||||
|
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
|
||||||
|
--output-filename /icefall/test-low.wav
|
||||||
|
|
||||||
|
ls -lh /icefall/test-low.wav
|
||||||
|
|
||||||
|
d=/icefall/vits-icefall-en_US-ljspeech-low
|
||||||
|
mkdir $d
|
||||||
|
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
|
||||||
|
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx
|
||||||
|
|
||||||
|
rm -rf icefall-tts-ljspeech-vits-low-2024-03-12
|
||||||
|
|
||||||
|
pushd $d
|
||||||
|
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
|
||||||
|
tar xf espeak-ng-data.tar.bz2
|
||||||
|
rm espeak-ng-data.tar.bz2
|
||||||
|
cd ..
|
||||||
|
tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
|
||||||
|
rm -rf vits-icefall-en_US-ljspeech-low
|
||||||
|
ls -lh *.tar.bz2
|
||||||
|
popd
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_data
|
||||||
|
train
|
||||||
|
infer
|
||||||
|
export_onnx
|
||||||
|
rm -rf vits/exp-{low,medium,high}
|
||||||
|
test_medium
|
||||||
|
test_low
|
3
.github/workflows/build-doc.yml
vendored
3
.github/workflows/build-doc.yml
vendored
@ -56,11 +56,14 @@ jobs:
|
|||||||
- name: Build doc
|
- name: Build doc
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
.github/scripts/generate-piper-phonemize-page.py
|
||||||
cd docs
|
cd docs
|
||||||
python3 -m pip install -r ./requirements.txt
|
python3 -m pip install -r ./requirements.txt
|
||||||
make html
|
make html
|
||||||
touch build/html/.nojekyll
|
touch build/html/.nojekyll
|
||||||
|
|
||||||
|
cp -v ../piper_phonemize.html ./build/html/
|
||||||
|
|
||||||
- name: Deploy
|
- name: Deploy
|
||||||
uses: peaceiris/actions-gh-pages@v3
|
uses: peaceiris/actions-gh-pages@v3
|
||||||
with:
|
with:
|
||||||
|
2
.github/workflows/build-docker-image.yml
vendored
2
.github/workflows/build-docker-image.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
# refer to https://github.com/actions/checkout
|
# refer to https://github.com/actions/checkout
|
||||||
|
102
.github/workflows/ljspeech.yml
vendored
Normal file
102
.github/workflows/ljspeech.yml
vendored
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
name: ljspeech
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ljspeech-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
generate_build_matrix:
|
||||||
|
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
|
||||||
|
# see https://github.com/pytorch/pytorch/pull/50633
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
outputs:
|
||||||
|
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
- name: Generating build matrix
|
||||||
|
id: set-matrix
|
||||||
|
run: |
|
||||||
|
# outputting for debugging purposes
|
||||||
|
python ./.github/scripts/docker/generate_build_matrix.py
|
||||||
|
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
|
||||||
|
echo "::set-output name=matrix::${MATRIX}"
|
||||||
|
|
||||||
|
ljspeech:
|
||||||
|
needs: generate_build_matrix
|
||||||
|
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Free space
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
ls -lh
|
||||||
|
df -h
|
||||||
|
rm -rf /opt/hostedtoolcache
|
||||||
|
df -h
|
||||||
|
echo "pwd: $PWD"
|
||||||
|
echo "github.workspace ${{ github.workspace }}"
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
uses: addnab/docker-run-action@v3
|
||||||
|
with:
|
||||||
|
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
|
||||||
|
options: |
|
||||||
|
--volume ${{ github.workspace }}/:/icefall
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PYTHONPATH=/icefall:$PYTHONPATH
|
||||||
|
cd /icefall
|
||||||
|
git config --global --add safe.directory /icefall
|
||||||
|
|
||||||
|
.github/scripts/ljspeech/TTS/run.sh
|
||||||
|
|
||||||
|
- name: display files
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
ls -lh
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
|
||||||
|
with:
|
||||||
|
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
|
||||||
|
path: ./*.wav
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
|
||||||
|
with:
|
||||||
|
name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
|
||||||
|
path: ./*.wav
|
||||||
|
|
||||||
|
- name: Release exported onnx models
|
||||||
|
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
|
||||||
|
uses: svenstaro/upload-release-action@v2
|
||||||
|
with:
|
||||||
|
file_glob: true
|
||||||
|
overwrite: true
|
||||||
|
file: vits-icefall-*.tar.bz2
|
||||||
|
repo_name: k2-fsa/sherpa-onnx
|
||||||
|
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||||
|
tag: tts-models
|
||||||
|
|
2
.github/workflows/run-docker-image.yml
vendored
2
.github/workflows/run-docker-image.yml
vendored
@ -14,7 +14,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
||||||
steps:
|
steps:
|
||||||
# refer to https://github.com/actions/checkout
|
# refer to https://github.com/actions/checkout
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
8
.github/workflows/style_check.yml
vendored
8
.github/workflows/style_check.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
|
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
|
||||||
# Click issue fixed in https://github.com/psf/black/pull/2966
|
# Click issue fixed in https://github.com/psf/black/pull/2966
|
||||||
|
|
||||||
- name: Run flake8
|
- name: Run flake8
|
||||||
@ -67,3 +67,9 @@ jobs:
|
|||||||
working-directory: ${{github.workspace}}
|
working-directory: ${{github.workspace}}
|
||||||
run: |
|
run: |
|
||||||
black --check --diff .
|
black --check --diff .
|
||||||
|
|
||||||
|
- name: Run isort
|
||||||
|
shell: bash
|
||||||
|
working-directory: ${{github.workspace}}
|
||||||
|
run: |
|
||||||
|
isort --check --diff .
|
||||||
|
@ -26,7 +26,7 @@ repos:
|
|||||||
# E121,E123,E126,E226,E24,E704,W503,W504
|
# E121,E123,E126,E226,E24,E704,W503,W504
|
||||||
|
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
rev: 5.11.5
|
rev: 5.10.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
args: ["--profile=black"]
|
args: ["--profile=black"]
|
||||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
|||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
# python 3.7
|
# python 3.7
|
||||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.3.torch1.12.1"
|
ARG K2_VERSION="1.24.4.dev20240223+cuda11.3.torch1.12.1"
|
||||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.3.torch1.12.1"
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.3.torch1.12.1"
|
||||||
ARG TORCHAUDIO_VERSION="0.12.1+cu113"
|
ARG TORCHAUDIO_VERSION="0.12.1+cu113"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
|||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
# python 3.9
|
# python 3.9
|
||||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.6.torch1.13.0"
|
ARG K2_VERSION="1.24.4.dev20240223+cuda11.6.torch1.13.0"
|
||||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.6.torch1.13.0"
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.6.torch1.13.0"
|
||||||
ARG TORCHAUDIO_VERSION="0.13.0+cu116"
|
ARG TORCHAUDIO_VERSION="0.13.0+cu116"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
|||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
# python 3.7
|
# python 3.7
|
||||||
ARG K2_VERSION="1.24.4.dev20240211+cuda10.2.torch1.9.0"
|
ARG K2_VERSION="1.24.4.dev20240223+cuda10.2.torch1.9.0"
|
||||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda10.2.torch1.9.0"
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda10.2.torch1.9.0"
|
||||||
ARG TORCHAUDIO_VERSION="0.9.0"
|
ARG TORCHAUDIO_VERSION="0.9.0"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
|||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
# python 3.10
|
# python 3.10
|
||||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.7.torch2.0.0"
|
ARG K2_VERSION="1.24.4.dev20240223+cuda11.7.torch2.0.0"
|
||||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.7.torch2.0.0"
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.7.torch2.0.0"
|
||||||
ARG TORCHAUDIO_VERSION="2.0.0+cu117"
|
ARG TORCHAUDIO_VERSION="2.0.0+cu117"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
|||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
# python 3.10
|
# python 3.10
|
||||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.1.0"
|
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.1.0"
|
||||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.1.0"
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.1.0"
|
||||||
ARG TORCHAUDIO_VERSION="2.1.0+cu118"
|
ARG TORCHAUDIO_VERSION="2.1.0+cu118"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
|||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
# python 3.10
|
# python 3.10
|
||||||
ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.1.0"
|
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.1.0"
|
||||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.1.0"
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.1.0"
|
||||||
ARG TORCHAUDIO_VERSION="2.1.0+cu121"
|
ARG TORCHAUDIO_VERSION="2.1.0+cu121"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
|||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
# python 3.10
|
# python 3.10
|
||||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.2.0"
|
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.0"
|
||||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.2.0"
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.0"
|
||||||
ARG TORCHAUDIO_VERSION="2.2.0+cu118"
|
ARG TORCHAUDIO_VERSION="2.2.0+cu118"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
|||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
# python 3.10
|
# python 3.10
|
||||||
ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.2.0"
|
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.0"
|
||||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.2.0"
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.0"
|
||||||
ARG TORCHAUDIO_VERSION="2.2.0+cu121"
|
ARG TORCHAUDIO_VERSION="2.2.0+cu121"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
70
docker/torch2.2.1-cuda11.8.dockerfile
Normal file
70
docker/torch2.2.1-cuda11.8.dockerfile
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.1"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.1"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.2.1+cu118"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
70
docker/torch2.2.1-cuda12.1.dockerfile
Normal file
70
docker/torch2.2.1-cuda12.1.dockerfile
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.1"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.1"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.2.1+cu121"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
@ -34,6 +34,8 @@ which will give you something like below:
|
|||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
"torch2.2.1-cuda12.1"
|
||||||
|
"torch2.2.1-cuda11.8"
|
||||||
"torch2.2.0-cuda12.1"
|
"torch2.2.0-cuda12.1"
|
||||||
"torch2.2.0-cuda11.8"
|
"torch2.2.0-cuda11.8"
|
||||||
"torch2.1.0-cuda12.1"
|
"torch2.1.0-cuda12.1"
|
||||||
|
225
docs/source/recipes/Finetune/adapter/finetune_adapter.rst
Normal file
225
docs/source/recipes/Finetune/adapter/finetune_adapter.rst
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
Finetune from a pre-trained Zipformer model with adapters
|
||||||
|
=========================================================
|
||||||
|
|
||||||
|
This tutorial shows you how to fine-tune a pre-trained **Zipformer**
|
||||||
|
transducer model on a new dataset with adapters.
|
||||||
|
Adapters are compact and efficient modules that can be integrated into a pre-trained model
|
||||||
|
to improve the model's performance on a new domain. Adapters are injected
|
||||||
|
between different modules in the well-trained neural network. During training, only the parameters
|
||||||
|
in the adapters will be updated. It achieves competitive performance
|
||||||
|
while requiring much less GPU memory than full fine-tuning. For more details about adapters,
|
||||||
|
please refer to the original `paper <https://arxiv.org/pdf/1902.00751.pdf#/>`_.
|
||||||
|
|
||||||
|
.. HINT::
|
||||||
|
|
||||||
|
We assume you have read the page :ref:`install icefall` and have set up
|
||||||
|
the environment for ``icefall``.
|
||||||
|
|
||||||
|
.. HINT::
|
||||||
|
|
||||||
|
We recommend that you use a GPU or several GPUs to run this recipe.
|
||||||
|
|
||||||
|
For illustration purposes, we fine-tune the Zipformer transducer model
|
||||||
|
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
|
||||||
|
own data for fine-tuning if you create a manifest for your new dataset.
|
||||||
|
|
||||||
|
Data preparation
|
||||||
|
----------------
|
||||||
|
|
||||||
|
Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
|
||||||
|
to prepare the fine-tune data used in this tutorial. We only require the small subset in GigaSpeech for this tutorial.
|
||||||
|
|
||||||
|
|
||||||
|
Model preparation
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
We are using the Zipformer model trained on full LibriSpeech (960 hours) as the initialization. The
|
||||||
|
checkpoint of the model can be downloaded via the following command:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
$ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
|
||||||
|
$ git lfs pull --include "pretrained.pt"
|
||||||
|
$ ln -s pretrained.pt epoch-99.pt
|
||||||
|
$ cd ../data/lang_bpe_500
|
||||||
|
$ git lfs pull --include bpe.model
|
||||||
|
$ cd ../../..
|
||||||
|
|
||||||
|
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
|
||||||
|
decoding on the GigaSpeech test sets:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
./zipformer/decode_gigaspeech.py \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--max-duration 1000 \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
|
||||||
|
You should see the following numbers:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
For dev, WER of different settings are:
|
||||||
|
greedy_search 20.06 best for dev
|
||||||
|
|
||||||
|
For test, WER of different settings are:
|
||||||
|
greedy_search 19.27 best for test
|
||||||
|
|
||||||
|
|
||||||
|
Fine-tune with adapter
|
||||||
|
----------------------
|
||||||
|
|
||||||
|
We insert 4 adapters with residual connection in each ``Zipformer2EncoderLayer``.
|
||||||
|
The original model parameters remain untouched during training and only the parameters of
|
||||||
|
the adapters are updated. The following command starts a fine-tuning experiment with adapters:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ do_finetune=1
|
||||||
|
$ use_adapters=1
|
||||||
|
$ adapter_dim=8
|
||||||
|
|
||||||
|
$ ./zipformer_adapter/train.py \
|
||||||
|
--world-size 2 \
|
||||||
|
--num-epochs 20 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--base-lr 0.045 \
|
||||||
|
--use-adapters $use_adapters --adapter-dim $adapter_dim \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--do-finetune $do_finetune \
|
||||||
|
--master-port 13022 \
|
||||||
|
--finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
|
||||||
|
--max-duration 1000
|
||||||
|
|
||||||
|
The following arguments are related to fine-tuning:
|
||||||
|
|
||||||
|
- ``--do-finetune``
|
||||||
|
If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
|
||||||
|
**Note that if you want to resume your fine-tuning experiment from certain epochs, you
|
||||||
|
need to set this to False.**
|
||||||
|
|
||||||
|
- ``--use-adapters``
|
||||||
|
If adapters are used during fine-tuning.
|
||||||
|
|
||||||
|
- ``--adapter-dim``
|
||||||
|
The bottleneck dimension of the adapter module. Typically a small number.
|
||||||
|
|
||||||
|
Note that the training log shows the total number of trainable parameters:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)
|
||||||
|
|
||||||
|
The trainable parameters only make up 1.15% of the entire model parameters, so the training will be much faster
|
||||||
|
and require less memory than full fine-tuning.
|
||||||
|
|
||||||
|
|
||||||
|
Decoding
|
||||||
|
--------
|
||||||
|
|
||||||
|
After training, let's test the WERs. To test the WERs on the GigaSpeech set,
|
||||||
|
you can execute the following command:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ epoch=20
|
||||||
|
$ avg=10
|
||||||
|
$ use_adapters=1
|
||||||
|
$ adapter_dim=8
|
||||||
|
|
||||||
|
$ ./zipformer/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--use-averaged-model 1 \
|
||||||
|
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||||
|
--max-duration 600 \
|
||||||
|
--use-adapters $use_adapters \
|
||||||
|
--adapter-dim $adapter_dim \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
|
||||||
|
You should see the following numbers:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
For dev, WER of different settings are:
|
||||||
|
greedy_search 15.44 best for dev
|
||||||
|
|
||||||
|
For test, WER of different settings are:
|
||||||
|
greedy_search 15.42 best for test
|
||||||
|
|
||||||
|
|
||||||
|
The WER on the test set is improved from 19.27 to 15.42, demonstrating the effectiveness of adapters.
|
||||||
|
|
||||||
|
The same model can be used to perform decoding on LibriSpeech test sets. You can deactivate the adapters
|
||||||
|
to keep the same performance of the original model:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ epoch=20
|
||||||
|
$ avg=1
|
||||||
|
$ use_adapters=0
|
||||||
|
$ adapter_dim=8
|
||||||
|
|
||||||
|
$ ./zipformer/decode.py \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg \
|
||||||
|
--use-averaged-model 1 \
|
||||||
|
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||||
|
--max-duration 600 \
|
||||||
|
--use-adapters $use_adapters \
|
||||||
|
--adapter-dim $adapter_dim \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
For test-clean, WER of different settings are:
|
||||||
|
greedy_search 2.23 best for test-clean
|
||||||
|
|
||||||
|
For test-other, WER of different settings are:
|
||||||
|
greedy_search 4.96 best for test-other
|
||||||
|
|
||||||
|
The numbers are the same as reported in `icefall <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md#normal-scaled-model-number-of-model-parameters-65549011-ie-6555-m>`_. So adapter-based
|
||||||
|
fine-tuning is also very flexible as the same model can be used for decoding on the original and target domain.
|
||||||
|
|
||||||
|
|
||||||
|
Export the model
|
||||||
|
----------------
|
||||||
|
|
||||||
|
After training, the model can be exported to ``onnx`` format easily using the following command:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ use_adapters=1
|
||||||
|
$ adapter_dim=8
|
||||||
|
|
||||||
|
$ ./zipformer_adapter/export-onnx.py \
|
||||||
|
--tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 1 \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||||
|
--use-adapters $use_adapters \
|
||||||
|
--adapter-dim $adapter_dim \
|
||||||
|
--num-encoder-layers "2,2,3,4,3,2" \
|
||||||
|
--downsampling-factor "1,2,4,8,4,2" \
|
||||||
|
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||||
|
--num-heads "4,4,4,8,4,4" \
|
||||||
|
--encoder-dim "192,256,384,512,384,256" \
|
||||||
|
--query-head-dim 32 \
|
||||||
|
--value-head-dim 12 \
|
||||||
|
--pos-head-dim 4 \
|
||||||
|
--pos-dim 48 \
|
||||||
|
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||||
|
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512 \
|
||||||
|
--causal False \
|
||||||
|
--chunk-size "16,32,64,-1" \
|
||||||
|
--left-context-frames "64,128,256,-1"
|
@ -13,3 +13,4 @@ data to improve the performance on new domains.
|
|||||||
:caption: Table of Contents
|
:caption: Table of Contents
|
||||||
|
|
||||||
from_supervised/finetune_zipformer
|
from_supervised/finetune_zipformer
|
||||||
|
adapter/finetune_adapter
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
VITS
|
VITS-LJSpeech
|
||||||
===============
|
===============
|
||||||
|
|
||||||
This tutorial shows you how to train a VITS model
|
This tutorial shows you how to train a VITS model
|
||||||
@ -13,6 +13,14 @@ with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
|
|||||||
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
|
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
|
||||||
|
|
||||||
|
|
||||||
|
Install extra dependencies
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
|
||||||
|
pip install numba espnet_tts_frontend
|
||||||
|
|
||||||
Data preparation
|
Data preparation
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
@ -56,7 +64,8 @@ Training
|
|||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir vits/exp \
|
--exp-dir vits/exp \
|
||||||
--tokens data/tokens.txt
|
--tokens data/tokens.txt \
|
||||||
|
--model-type high \
|
||||||
--max-duration 500
|
--max-duration 500
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
@ -64,6 +73,11 @@ Training
|
|||||||
You can adjust the hyper-parameters to control the size of the VITS model and
|
You can adjust the hyper-parameters to control the size of the VITS model and
|
||||||
the training configurations. For more details, please run ``./vits/train.py --help``.
|
the training configurations. For more details, please run ``./vits/train.py --help``.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
If you want a model that runs faster on CPU, please use ``--model-type low``
|
||||||
|
or ``--model-type medium``.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
The training can take a long time (usually a couple of days).
|
The training can take a long time (usually a couple of days).
|
||||||
@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
|
|||||||
Export models
|
Export models
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
|
Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
|
||||||
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
|
``vits-epoch-*.onnx``.
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
@ -120,4 +134,68 @@ Download pretrained models
|
|||||||
If you don't want to train from scratch, you can download the pretrained models
|
If you don't want to train from scratch, you can download the pretrained models
|
||||||
by visiting the following link:
|
by visiting the following link:
|
||||||
|
|
||||||
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
|
- ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
|
||||||
|
- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
|
||||||
|
- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_
|
||||||
|
|
||||||
|
Usage in sherpa-onnx
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
The following describes how to test the exported ONNX model in `sherpa-onnx`_.
|
||||||
|
|
||||||
|
.. hint::
|
||||||
|
|
||||||
|
`sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
|
||||||
|
Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.
|
||||||
|
|
||||||
|
We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
|
||||||
|
Please refer to `<https://k2-fsa.github.io/sherpa/onnx/>`_
|
||||||
|
for more documentation.
|
||||||
|
|
||||||
|
Install sherpa-onnx
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install sherpa-onnx
|
||||||
|
|
||||||
|
To check that you have installed `sherpa-onnx`_ successfully, please run:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
which sherpa-onnx-offline-tts
|
||||||
|
sherpa-onnx-offline-tts --help
|
||||||
|
|
||||||
|
Download lexicon files
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd /tmp
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
|
||||||
|
tar xf espeak-ng-data.tar.bz2
|
||||||
|
|
||||||
|
Run sherpa-onnx
|
||||||
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd egs/ljspeech/TTS
|
||||||
|
|
||||||
|
sherpa-onnx-offline-tts \
|
||||||
|
--vits-model=vits/exp/vits-epoch-1000.onnx \
|
||||||
|
--vits-tokens=data/tokens.txt \
|
||||||
|
--vits-data-dir=/tmp/espeak-ng-data \
|
||||||
|
--num-threads=1 \
|
||||||
|
--output-filename=./high.wav \
|
||||||
|
"Ask not what your country can do for you; ask what you can do for your country."
|
||||||
|
|
||||||
|
.. hint::
|
||||||
|
|
||||||
|
You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
|
||||||
|
as it is being generated.
|
||||||
|
|
||||||
|
You should get a file ``high.wav`` after running the above command.
|
||||||
|
|
||||||
|
Congratulations! You have successfully trained and exported a text-to-speech
|
||||||
|
model and run it with `sherpa-onnx`_.
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
VITS
|
VITS-VCTK
|
||||||
===============
|
===============
|
||||||
|
|
||||||
This tutorial shows you how to train a VITS model
|
This tutorial shows you how to train a VITS model
|
||||||
|
@ -19,7 +19,9 @@ The following table lists the differences among them.
|
|||||||
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
|
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
|
||||||
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
|
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
|
||||||
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
|
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
|
||||||
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 |
|
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
|
||||||
|
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
|
||||||
|
|
||||||
|
|
||||||
The decoder in `transducer_stateless` is modified from the paper
|
The decoder in `transducer_stateless` is modified from the paper
|
||||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||||
|
@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||||
log "Stage 11: Train RNN LM model"
|
log "Stage 12: Train RNN LM model"
|
||||||
python ../../../icefall/rnn_lm/train.py \
|
python ../../../icefall/rnn_lm/train.py \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--world-size 1 \
|
--world-size 1 \
|
||||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -881,9 +882,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error()
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||||
|
@ -85,6 +85,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
@ -878,9 +879,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -871,9 +872,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -250,7 +250,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=2,
|
||||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
@ -882,9 +883,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
@ -881,9 +882,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -86,6 +86,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
@ -985,9 +986,7 @@ def train_one_epoch(
|
|||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
|
@ -83,6 +83,7 @@ from icefall.checkpoint import (
|
|||||||
update_averaged_model,
|
update_averaged_model,
|
||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -570,9 +571,7 @@ def train_one_epoch(
|
|||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
|
@ -29,7 +29,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
def compute_fbank_aishell2(
|
||||||
|
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(8, os.cpu_count())
|
||||||
|
|
||||||
dataset_parts = (
|
dataset_parts = (
|
||||||
"train",
|
"train",
|
||||||
@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
|||||||
list(manifests.keys()),
|
list(manifests.keys()),
|
||||||
dataset_parts,
|
dataset_parts,
|
||||||
)
|
)
|
||||||
|
if whisper_fbank:
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
|||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition and perturb_speed:
|
if "train" in partition and perturb_speed:
|
||||||
logging.info(f"Doing speed perturb")
|
logging.info("Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -111,7 +124,12 @@ def get_args():
|
|||||||
default=False,
|
default=False,
|
||||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -122,5 +140,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell2(
|
compute_fbank_aishell2(
|
||||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
perturb_speed=args.perturb_speed,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
whisper_mel_bins=80
|
||||||
|
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
|
||||||
|
log "Stage 30: Compute whisper fbank for aishell2"
|
||||||
|
if [ ! -f data/fbank/.aishell2.whisper.done ]; then
|
||||||
|
mkdir -p data/fbank
|
||||||
|
./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||||
|
touch data/fbank/.aishell2.whisper.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
log "Stage 4: Compute fbank for musan"
|
log "Stage 4: Compute fbank for musan"
|
||||||
if [ ! -f data/fbank/.msuan.done ]; then
|
if [ ! -f data/fbank/.msuan.done ]; then
|
||||||
|
@ -29,7 +29,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
def compute_fbank_aishell4(
|
||||||
|
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests/aishell4")
|
src_dir = Path("data/manifests/aishell4")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(8, os.cpu_count())
|
||||||
|
|
||||||
dataset_parts = (
|
dataset_parts = (
|
||||||
"train_S",
|
"train_S",
|
||||||
@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
|||||||
dataset_parts,
|
dataset_parts,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
if whisper_fbank:
|
||||||
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
|||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition and perturb_speed:
|
if "train" in partition and perturb_speed:
|
||||||
logging.info(f"Doing speed perturb")
|
logging.info("Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
|||||||
# when an executor is specified, make more partitions
|
# when an executor is specified, make more partitions
|
||||||
num_jobs=num_jobs if ex is None else 80,
|
num_jobs=num_jobs if ex is None else 80,
|
||||||
executor=ex,
|
executor=ex,
|
||||||
storage_type=ChunkedLilcomHdf5Writer,
|
storage_type=LilcomChunkyWriter,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("About splitting cuts into smaller chunks")
|
logging.info("About splitting cuts into smaller chunks")
|
||||||
@ -121,7 +135,12 @@ def get_args():
|
|||||||
default=False,
|
default=False,
|
||||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -132,5 +151,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell4(
|
compute_fbank_aishell4(
|
||||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
perturb_speed=args.perturb_speed,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
stage=-1
|
stage=-1
|
||||||
stop_stage=100
|
stop_stage=7
|
||||||
perturb_speed=true
|
perturb_speed=true
|
||||||
|
|
||||||
|
|
||||||
@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
log "Stage 2: Process aishell4"
|
log "Stage 2: Compute fbank for aishell4"
|
||||||
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
|
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
|
||||||
mkdir -p data/fbank/aishell4
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
|
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
|
||||||
touch data/fbank/aishell4/.fbank.done
|
touch data/fbank/.fbank.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
whisper_mel_bins=80
|
||||||
|
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||||
|
log "Stage 20: Compute whisper fbank for aishell4"
|
||||||
|
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
|
||||||
|
mkdir -p data/fbank
|
||||||
|
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||||
|
touch data/fbank/.fbank.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
log "Stage 5: Compute fbank for aishell4"
|
log "Stage 5: Prepare char based lang"
|
||||||
if [ ! -f data/fbank/.aishell4.done ]; then
|
|
||||||
mkdir -p data/fbank
|
|
||||||
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
|
|
||||||
touch data/fbank/.aishell4.done
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|
||||||
log "Stage 6: Prepare char based lang"
|
|
||||||
lang_char_dir=data/lang_char
|
lang_char_dir=data/lang_char
|
||||||
mkdir -p $lang_char_dir
|
mkdir -p $lang_char_dir
|
||||||
|
|
||||||
|
@ -29,7 +29,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
|
def compute_fbank_alimeeting(
|
||||||
|
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests/alimeeting")
|
src_dir = Path("data/manifests/alimeeting")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(8, os.cpu_count())
|
||||||
|
|
||||||
dataset_parts = (
|
dataset_parts = (
|
||||||
"train",
|
"train",
|
||||||
@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
|||||||
"test",
|
"test",
|
||||||
)
|
)
|
||||||
|
|
||||||
prefix = "alimeeting"
|
prefix = "alimeeting-far"
|
||||||
suffix = "jsonl.gz"
|
suffix = "jsonl.gz"
|
||||||
manifests = read_manifests_if_cached(
|
manifests = read_manifests_if_cached(
|
||||||
dataset_parts=dataset_parts,
|
dataset_parts=dataset_parts,
|
||||||
@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
|||||||
dataset_parts,
|
dataset_parts,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
if whisper_fbank:
|
||||||
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
|||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition and perturb_speed:
|
if "train" in partition and perturb_speed:
|
||||||
logging.info(f"Doing speed perturb")
|
logging.info("Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -121,7 +135,12 @@ def get_args():
|
|||||||
default=False,
|
default=False,
|
||||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use the Whisper Fbank feature extractor. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -132,5 +151,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_alimeeting(
|
compute_fbank_alimeeting(
|
||||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
perturb_speed=args.perturb_speed,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
stage=-1
|
stage=-1
|
||||||
stop_stage=100
|
stop_stage=7
|
||||||
perturb_speed=true
|
perturb_speed=true
|
||||||
|
|
||||||
# We assume dl_dir (download dir) contains the following
|
# We assume dl_dir (download dir) contains the following
|
||||||
@ -15,7 +15,7 @@ perturb_speed=true
|
|||||||
#
|
#
|
||||||
# - $dl_dir/alimeeting
|
# - $dl_dir/alimeeting
|
||||||
# This directory contains the following files downloaded from
|
# This directory contains the following files downloaded from
|
||||||
# https://openslr.org/62/
|
# https://openslr.org/119/
|
||||||
#
|
#
|
||||||
# - Train_Ali_far.tar.gz
|
# - Train_Ali_far.tar.gz
|
||||||
# - Train_Ali_near.tar.gz
|
# - Train_Ali_near.tar.gz
|
||||||
@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
log "Stage 2: Process alimeeting"
|
log "Stage 2: compute fbank for alimeeting"
|
||||||
if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
|
if [ ! -f data/fbank/.fbank.done ]; then
|
||||||
mkdir -p data/fbank/alimeeting
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
|
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
|
||||||
|
touch data/fbank/.fbank.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
whisper_mel_bins=80
|
||||||
|
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||||
|
log "Stage 20: compute whisper fbank for alimeeting"
|
||||||
|
if [ ! -f data/fbank/.fbank.done ]; then
|
||||||
|
mkdir -p data/fbank
|
||||||
|
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||||
|
touch data/fbank/.fbank.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
log "Stage 5: Compute fbank for alimeeting"
|
log "Stage 5: Prepare char based lang"
|
||||||
if [ ! -f data/fbank/.alimeeting.done ]; then
|
|
||||||
mkdir -p data/fbank
|
|
||||||
./local/compute_fbank_alimeeting.py --perturb-speed True
|
|
||||||
touch data/fbank/.alimeeting.done
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|
||||||
log "Stage 6: Prepare char based lang"
|
|
||||||
lang_char_dir=data/lang_char
|
lang_char_dir=data/lang_char
|
||||||
mkdir -p $lang_char_dir
|
mkdir -p $lang_char_dir
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting
|
|||||||
#
|
#
|
||||||
# - $dl_dir/alimeeting
|
# - $dl_dir/alimeeting
|
||||||
# This directory contains the following files downloaded from
|
# This directory contains the following files downloaded from
|
||||||
# https://openslr.org/62/
|
# https://openslr.org/119/
|
||||||
#
|
#
|
||||||
# - Train_Ali_far.tar.gz
|
# - Train_Ali_far.tar.gz
|
||||||
# - Train_Ali_near.tar.gz
|
# - Train_Ali_near.tar.gz
|
||||||
|
@ -70,6 +70,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
@ -851,9 +852,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -69,6 +69,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
@ -842,9 +843,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -75,6 +75,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
@ -1138,9 +1139,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -75,6 +75,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
@ -1129,9 +1130,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -1 +0,0 @@
|
|||||||
../../../librispeech/ASR/local/compile_hlg.py
|
|
168
egs/commonvoice/ASR/local/compile_hlg.py
Executable file
168
egs/commonvoice/ASR/local/compile_hlg.py
Executable file
@ -0,0 +1,168 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Zengrui Jin,)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script takes as input lang_dir and generates HLG from
|
||||||
|
|
||||||
|
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
|
||||||
|
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||||
|
|
||||||
|
Caution: We use a lexicon that contains disambiguation symbols
|
||||||
|
|
||||||
|
- G, the LM, built from lang_dir/lm/<lm>.fst.txt, where <lm> is the value of --lm (default: G_3_gram)
|
||||||
|
|
||||||
|
The generated HLG is saved in $lang_dir/HLG.pt
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
    """Parse the command-line options of this script.

    Return:
      An ``argparse.Namespace`` with two string attributes:

        - ``lm``: stem name of the LM used in HLG compiling
          (default ``G_3_gram``).
        - ``lang_dir``: the input and output directory (mandatory).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lm",
        type=str,
        default="G_3_gram",
        help="Stem name for LM used in HLG compiling.",
    )
    parser.add_argument(
        "--lang-dir",
        type=str,
        # Without `required`, a missing --lang-dir yields None here and only
        # fails later with a TypeError when it is turned into a Path.
        required=True,
        help="Input and output directory.",
    )

    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
    """Build the HLG graph from the files found in ``lang_dir``.

    The following files are read:

      - ``{lang_dir}/L_disambig.pt``: the lexicon with disambiguation symbols.
      - ``{lang_dir}/lm/{lm}.pt`` if it exists, otherwise
        ``{lang_dir}/lm/{lm}.fst.txt``.

    Side effect: when only the ``.fst.txt`` form of G exists, the parsed G is
    saved to ``{lang_dir}/lm/{lm}.pt`` so that later runs can load it directly.

    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
      lm:
        The language stem base name.

    Return:
      An FSA representing HLG.
    """
    lexicon = Lexicon(lang_dir)
    max_token_id = max(lexicon.tokens)
    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
        logging.info(f"Loading pre-compiled {lm}")
        d = torch.load(f"{lang_dir}/lm/{lm}.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info(f"Loading {lm}.fst.txt")
        with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        # Cache the parsed LM for subsequent runs.
        torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")

    # IDs of "#0" in the token and word tables; every ID at or above these is
    # zeroed out below once determinization is done.
    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    # Modify a local handle and assign it back instead of writing
    # `LG.labels[LG.labels >= first_token_disambig_id] = 0` directly;
    # see https://github.com/k2-fsa/k2/pull/1140
    labels = LG.labels
    labels[labels >= first_token_disambig_id] = 0
    LG.labels = labels

    assert isinstance(LG.aux_labels, k2.RaggedTensor)
    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    # Drop the zeros that replaced the word disambiguation symbols above.
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    # CAUTION: The name of the inner_labels is fixed
    # to `tokens`. If you want to change it, please
    # also change other places in icefall that are using
    # it.
    HLG = k2.compose(H, LG, inner_labels="tokens")

    logging.info("Connecting HLG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)
    logging.info(f"HLG.shape: {HLG.shape}")

    return HLG
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Compile HLG for the directory given by ``--lang-dir``.

    The result is saved as ``HLG.pt`` inside that directory. If ``HLG.pt``
    is already there, a message is logged and nothing is recompiled.
    """
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "HLG.pt").is_file():
        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    HLG = compile_HLG(lang_dir, args.lm)
    logging.info(f"Saving HLG.pt to {lang_dir}")
    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Timestamp, level and source location in front of every log message.
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
|
@ -1 +0,0 @@
|
|||||||
../../../librispeech/ASR/local/compile_lg.py
|
|
149
egs/commonvoice/ASR/local/compile_lg.py
Executable file
149
egs/commonvoice/ASR/local/compile_lg.py
Executable file
@ -0,0 +1,149 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Kang Wei,
|
||||||
|
# Zengrui Jin,)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script takes as input lang_dir and generates LG from
|
||||||
|
|
||||||
|
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||||
|
|
||||||
|
Caution: We use a lexicon that contains disambiguation symbols
|
||||||
|
|
||||||
|
- G, the LM, built from lang_dir/lm/G_3_gram.fst.txt
|
||||||
|
|
||||||
|
The generated LG is saved in $lang_dir/LG.pt
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
    """Parse the command-line options of this script.

    Return:
      An ``argparse.Namespace`` with two string attributes:

        - ``lang_dir``: the input and output directory (mandatory).
        - ``lm``: stem name of the LM used in LG compiling
          (default ``G_3_gram``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        # Without `required`, a missing --lang-dir yields None here and only
        # fails later with a TypeError when it is turned into a Path.
        required=True,
        help="Input and output directory.",
    )
    parser.add_argument(
        "--lm",
        type=str,
        default="G_3_gram",
        # This script builds LG, not HLG.
        help="Stem name for LM used in LG compiling.",
    )

    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
    """Build the LG graph from the files found in ``lang_dir``.

    The following files are read:

      - ``{lang_dir}/L_disambig.pt``: the lexicon with disambiguation symbols.
      - ``{lang_dir}/lm/{lm}.pt`` if it exists, otherwise
        ``{lang_dir}/lm/{lm}.fst.txt``.

    Side effect: when only the ``.fst.txt`` form of G exists, the parsed G is
    saved to ``{lang_dir}/lm/{lm}.pt`` so that later runs can load it directly.

    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
      lm:
        The language stem base name.

    Return:
      An FSA representing LG.
    """
    lexicon = Lexicon(lang_dir)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
        logging.info(f"Loading pre-compiled {lm}")
        d = torch.load(f"{lang_dir}/lm/{lm}.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info(f"Loading {lm}.fst.txt")
        with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        # Cache the parsed LM for subsequent runs.
        torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")

    # IDs of "#0" in the token and word tables; every ID at or above these is
    # zeroed out below once determinization is done.
    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    # Modify a local handle and assign it back instead of writing
    # `LG.labels[LG.labels >= first_token_disambig_id] = 0` directly;
    # see https://github.com/k2-fsa/k2/pull/1140
    labels = LG.labels
    labels[labels >= first_token_disambig_id] = 0
    LG.labels = labels

    assert isinstance(LG.aux_labels, k2.RaggedTensor)
    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    # Drop the zeros that replaced the word disambiguation symbols above.
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    return LG
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Compile LG for the directory given by ``--lang-dir``.

    The result is saved as ``LG.pt`` inside that directory. If ``LG.pt``
    is already there, a message is logged and nothing is recompiled.
    """
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "LG.pt").is_file():
        logging.info(f"{lang_dir}/LG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir, args.lm)
    logging.info(f"Saving LG.pt to {lang_dir}")
    torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Timestamp, level and source location in front of every log message.
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
|
@ -48,8 +48,27 @@ def normalize_text(utt: str, language: str) -> str:
|
|||||||
utt = re.sub("’", "'", utt)
|
utt = re.sub("’", "'", utt)
|
||||||
if language == "en":
|
if language == "en":
|
||||||
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
|
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
|
||||||
if language == "fr":
|
elif language == "fr":
|
||||||
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
|
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
|
||||||
|
elif language == "pl":
|
||||||
|
return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
|
||||||
|
elif language == "yue":
|
||||||
|
return (
|
||||||
|
utt.replace(" ", "")
|
||||||
|
.replace(",", "")
|
||||||
|
.replace("。", " ")
|
||||||
|
.replace("?", "")
|
||||||
|
.replace("!", "")
|
||||||
|
.replace("?", "")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"""
|
||||||
|
Text normalization not implemented for language: {language},
|
||||||
|
please consider implementing it in the local/preprocess_commonvoice.py
|
||||||
|
or raise an issue on GitHub to request it.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_commonvoice(
|
def preprocess_commonvoice(
|
||||||
|
@ -381,9 +381,11 @@ class CommonVoiceAsrDataModule:
|
|||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=(
|
||||||
if self.args.on_the_fly_feats
|
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
else eval(self.args.input_strategy)(),
|
if self.args.on_the_fly_feats
|
||||||
|
else eval(self.args.input_strategy)()
|
||||||
|
),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
|
@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from icefall import is_module_available
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
import torch
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -79,6 +79,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -871,9 +872,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
@ -232,7 +232,7 @@ class CommonVoiceAsrDataModule:
|
|||||||
logging.info("About to get Musan cuts")
|
logging.info("About to get Musan cuts")
|
||||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||||
transforms.append(
|
transforms.append(
|
||||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Disable MUSAN")
|
logging.info("Disable MUSAN")
|
||||||
@ -315,8 +315,8 @@ class CommonVoiceAsrDataModule:
|
|||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
@ -383,9 +383,11 @@ class CommonVoiceAsrDataModule:
|
|||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=(
|
||||||
if self.args.on_the_fly_feats
|
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
else eval(self.args.input_strategy)(),
|
if self.args.on_the_fly_feats
|
||||||
|
else eval(self.args.input_strategy)()
|
||||||
|
),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
|
@ -889,9 +889,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -965,9 +966,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
@ -888,9 +889,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
@ -909,9 +910,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -70,9 +70,9 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from tokenizer import Tokenizer
|
from tokenizer import Tokenizer
|
||||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
@ -908,9 +909,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -23,6 +23,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from lhotse import CutSet, SupervisionSegment
|
from lhotse import CutSet, SupervisionSegment
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
# Similar text filtering and normalization procedure as in:
|
# Similar text filtering and normalization procedure as in:
|
||||||
|
@ -76,6 +76,7 @@ from beam_search import (
|
|||||||
)
|
)
|
||||||
from gigaspeech_scoring import asr_text_post_processing
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
|
@ -88,7 +88,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import GigaSpeechAsrDataModule
|
from asr_datamodule import GigaSpeechAsrDataModule
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -1031,9 +1032,7 @@ def train_one_epoch(
|
|||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
|
@ -42,12 +42,10 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import GigaSpeechAsrDataModule
|
from asr_datamodule import GigaSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import keywords_search
|
||||||
keywords_search,
|
from lhotse.cut import Cut
|
||||||
)
|
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from lhotse.cut import Cut
|
|
||||||
from icefall import ContextGraph
|
from icefall import ContextGraph
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -76,25 +76,6 @@ from torch import Tensor
|
|||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from icefall import diagnostics
|
|
||||||
from icefall.checkpoint import remove_checkpoints
|
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
|
||||||
from icefall.checkpoint import (
|
|
||||||
save_checkpoint_with_global_batch_idx,
|
|
||||||
update_averaged_model,
|
|
||||||
)
|
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.hooks import register_inf_check_hooks
|
|
||||||
from icefall.utils import (
|
|
||||||
AttributeDict,
|
|
||||||
MetricsTracker,
|
|
||||||
get_parameter_groups_with_lrs,
|
|
||||||
setup_logger,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
from train import (
|
from train import (
|
||||||
add_model_arguments,
|
add_model_arguments,
|
||||||
add_training_arguments,
|
add_training_arguments,
|
||||||
@ -110,6 +91,25 @@ from train import (
|
|||||||
set_batch_count,
|
set_batch_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from icefall import diagnostics
|
||||||
|
from icefall.checkpoint import remove_checkpoints
|
||||||
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
save_checkpoint_with_global_batch_idx,
|
||||||
|
update_averaged_model,
|
||||||
|
)
|
||||||
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
|
from icefall.hooks import register_inf_check_hooks
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
MetricsTracker,
|
||||||
|
get_parameter_groups_with_lrs,
|
||||||
|
setup_logger,
|
||||||
|
str2bool,
|
||||||
|
)
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
|
|
||||||
@ -372,9 +372,7 @@ def train_one_epoch(
|
|||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -1034,9 +1035,7 @@ def train_one_epoch(
|
|||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
|
@ -85,6 +85,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
@ -1169,9 +1170,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
@ -1056,9 +1057,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -232,7 +232,7 @@ class LibriHeavyAsrDataModule:
|
|||||||
logging.info("About to get Musan cuts")
|
logging.info("About to get Musan cuts")
|
||||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||||
transforms.append(
|
transforms.append(
|
||||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Disable MUSAN")
|
logging.info("Disable MUSAN")
|
||||||
|
@ -93,6 +93,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -1036,9 +1037,7 @@ def train_one_epoch(
|
|||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
|
@ -425,9 +425,11 @@ class LibriHeavyAsrDataModule:
|
|||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=(
|
||||||
if self.args.on_the_fly_feats
|
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
else PrecomputedFeatures(),
|
if self.args.on_the_fly_feats
|
||||||
|
else PrecomputedFeatures()
|
||||||
|
),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
|
@ -103,6 +103,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -1051,9 +1052,7 @@ def train_one_epoch(
|
|||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
|
@ -117,6 +117,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -855,9 +856,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
decoder_dim=params.decoder_dim,
|
decoder_dim=params.decoder_dim,
|
||||||
joiner_dim=params.joiner_dim,
|
joiner_dim=params.joiner_dim,
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
context_dim=4 * 768
|
context_dim=(
|
||||||
if params.context_injection
|
4 * 768 if params.context_injection else -1
|
||||||
else -1, # the output dim of text encoder
|
), # the output dim of text encoder
|
||||||
context_injection=params.context_injection,
|
context_injection=params.context_injection,
|
||||||
)
|
)
|
||||||
return joiner
|
return joiner
|
||||||
@ -1398,9 +1399,7 @@ def train_one_epoch(
|
|||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
|
@ -35,8 +35,7 @@ The following table lists the differences among them.
|
|||||||
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
|
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
|
||||||
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
|
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
|
||||||
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
|
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
|
||||||
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | Finetune `zipformer` with adapter |
|
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter-efficient adapters |
|
||||||
| `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune `zipformer` with LoRA |
|
|
||||||
|
|
||||||
The decoder in `transducer_stateless` is modified from the paper
|
The decoder in `transducer_stateless` is modified from the paper
|
||||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||||
|
@ -24,8 +24,7 @@ To run this file, do:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from train import get_ctc_model, get_params
|
||||||
from train import get_params, get_ctc_model
|
|
||||||
|
|
||||||
|
|
||||||
def test_model():
|
def test_model():
|
||||||
|
@ -59,9 +59,9 @@ import onnx
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
|
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||||
from emformer import Emformer
|
from emformer import Emformer
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -39,7 +39,7 @@ Usage of this script:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
@ -47,7 +47,6 @@ import torch
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import torch.multiprocessing as mp
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
|
||||||
from lhotse import CutSet, load_manifest_lazy
|
from lhotse import CutSet, load_manifest_lazy
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.supervision import AlignmentItem
|
|
||||||
from lhotse.serialization import SequentialJsonlWriter
|
from lhotse.serialization import SequentialJsonlWriter
|
||||||
|
from lhotse.supervision import AlignmentItem
|
||||||
|
|
||||||
|
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
from icefall import is_module_available
|
from icefall import is_module_available
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from train import add_model_arguments, get_encoder_model, get_params
|
||||||
|
|
||||||
from icefall.profiler import get_model_profile
|
from icefall.profiler import get_model_profile
|
||||||
from train import get_encoder_model, add_model_arguments, get_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from icefall import is_module_available
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
import torch
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -76,8 +76,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from librispeech import LibriSpeech
|
from librispeech import LibriSpeech
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from scaling import BasicNorm, DoubleSwish
|
||||||
from typing import Tuple
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||||
|
|
||||||
from icefall.profiler import get_model_profile
|
from icefall.profiler import get_model_profile
|
||||||
from scaling import BasicNorm, DoubleSwish
|
|
||||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -82,8 +82,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||||
|
|
||||||
# The force alignment problem can be formulated as finding
|
# The force alignment problem can be formulated as finding
|
||||||
|
@ -107,9 +107,6 @@ import k2
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# from asr_datamodule import LibriSpeechAsrDataModule
|
|
||||||
from gigaspeech import GigaSpeechAsrDataModule
|
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_nbest,
|
fast_beam_search_nbest,
|
||||||
@ -120,6 +117,9 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from gigaspeech import GigaSpeechAsrDataModule
|
||||||
from gigaspeech_scoring import asr_text_post_processing
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
@ -80,6 +80,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -976,9 +977,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -65,16 +65,15 @@ from typing import Dict, List
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from scaling import BasicNorm, DoubleSwish
|
||||||
from typing import Tuple
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||||
|
|
||||||
from icefall.profiler import get_model_profile
|
from icefall.profiler import get_model_profile
|
||||||
from scaling import BasicNorm, DoubleSwish
|
|
||||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -24,7 +24,6 @@ To run this file, do:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -878,9 +879,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -902,9 +903,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -118,8 +118,8 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -18,10 +18,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling import (
|
from scaling import ActivationBalancer, ScaledConv1d
|
||||||
ActivationBalancer,
|
|
||||||
ScaledConv1d,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LConv(nn.Module):
|
class LConv(nn.Module):
|
||||||
|
@ -52,7 +52,7 @@ import onnxruntime as ort
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
|
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user