Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

Commit 390f01653f: resolve conflict
1 .github/scripts/.gitignore (vendored, new file)

@@ -0,0 +1 @@
piper_phonemize.html
29 .github/scripts/generate-piper-phonemize-page.py (vendored, new executable file)

@@ -0,0 +1,29 @@
#!/usr/bin/env python3


def main():
    prefix = (
        "https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
    )
    files = [
        "piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
    ]
    with open("piper_phonemize.html", "w") as f:
        for file in files:
            url = prefix + file
            f.write(f'<a href="{url}">{file}</a><br/>\n')


if __name__ == "__main__":
    main()
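The generated piper_phonemize.html is a plain pip "find-links" page: each wheel becomes an anchor tag, so pip can resolve the package from it with "python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html". That is exactly how the CI script and the build-doc workflow below consume it.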
6 .github/scripts/librispeech/ASR/run.sh (vendored)

@@ -15,9 +15,9 @@ function prepare_data() {
   # cause OOM error for CI later.
   mkdir -p download/lm
   pushd download/lm
-  wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
-  wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
-  wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+  wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
+  wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
+  wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt
   ls -lh
   gunzip librispeech-lm-norm.txt.gz
157 .github/scripts/ljspeech/TTS/run.sh (vendored, new executable file)

@@ -0,0 +1,157 @@
#!/usr/bin/env bash

set -ex

python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
python3 -m pip install espnet_tts_frontend
python3 -m pip install numba

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/ljspeech/TTS

sed -i.bak s/600/8/g ./prepare.sh
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
sed -i.bak s/500/5/g ./prepare.sh
git diff

function prepare_data() {
  # We have created a subset of the data for testing
  #
  mkdir download
  pushd download
  wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
  tar xvf LJSpeech-1.1.tar.bz2
  popd

  ./prepare.sh
  tree .
}

function train() {
  pushd ./vits
  sed -i.bak s/200/3/g ./train.py
  git diff .
  popd

  for t in low medium high; do
    ./vits/train.py \
      --exp-dir vits/exp-$t \
      --model-type $t \
      --num-epochs 1 \
      --save-every-n 1 \
      --num-buckets 2 \
      --tokens data/tokens.txt \
      --max-duration 20

    ls -lh vits/exp-$t
  done
}

function infer() {
  for t in low medium high; do
    ./vits/infer.py \
      --num-buckets 2 \
      --model-type $t \
      --epoch 1 \
      --exp-dir ./vits/exp-$t \
      --tokens data/tokens.txt \
      --max-duration 20
  done
}

function export_onnx() {
  for t in low medium high; do
    ./vits/export-onnx.py \
      --model-type $t \
      --epoch 1 \
      --exp-dir ./vits/exp-$t \
      --tokens data/tokens.txt

    ls -lh vits/exp-$t/
  done
}

function test_medium() {
  git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12

  ./vits/export-onnx.py \
    --model-type medium \
    --epoch 820 \
    --exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
    --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt

  ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp

  ./vits/test_onnx.py \
    --model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
    --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
    --output-filename /icefall/test-medium.wav

  ls -lh /icefall/test-medium.wav

  d=/icefall/vits-icefall-en_US-ljspeech-medium
  mkdir $d
  cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
  cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx

  rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12

  pushd $d
  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
  tar xf espeak-ng-data.tar.bz2
  rm espeak-ng-data.tar.bz2
  cd ..
  tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
  rm -rf vits-icefall-en_US-ljspeech-medium
  ls -lh *.tar.bz2
  popd
}

function test_low() {
  git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12

  ./vits/export-onnx.py \
    --model-type low \
    --epoch 1600 \
    --exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
    --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt

  ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp

  ./vits/test_onnx.py \
    --model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
    --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
    --output-filename /icefall/test-low.wav

  ls -lh /icefall/test-low.wav

  d=/icefall/vits-icefall-en_US-ljspeech-low
  mkdir $d
  cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
  cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx

  rm -rf icefall-tts-ljspeech-vits-low-2024-03-12

  pushd $d
  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
  tar xf espeak-ng-data.tar.bz2
  rm espeak-ng-data.tar.bz2
  cd ..
  tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
  rm -rf vits-icefall-en_US-ljspeech-low
  ls -lh *.tar.bz2
  popd
}

prepare_data
train
infer
export_onnx
rm -rf vits/exp-{low,medium,high}
test_medium
test_low
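The sed edits near the top of this script shrink subset sizes and iteration counts in prepare.sh and vits/train.py, presumably so that the full prepare/train/infer/export pipeline fits within CI time limits.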
3 .github/workflows/build-doc.yml (vendored)

@@ -56,11 +56,14 @@ jobs:
       - name: Build doc
         shell: bash
         run: |
+          .github/scripts/generate-piper-phonemize-page.py
           cd docs
           python3 -m pip install -r ./requirements.txt
           make html
           touch build/html/.nojekyll
+
+          cp -v ../piper_phonemize.html ./build/html/

       - name: Deploy
         uses: peaceiris/actions-gh-pages@v3
         with:
2 .github/workflows/build-docker-image.yml (vendored)

@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+        image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]

     steps:
       # refer to https://github.com/actions/checkout
102 .github/workflows/ljspeech.yml (vendored, new file)

@@ -0,0 +1,102 @@
name: ljspeech

on:
  push:
    branches:
      - master

  pull_request:
    branches:
      - master

  workflow_dispatch:

concurrency:
  group: ljspeech-${{ github.ref }}
  cancel-in-progress: true

jobs:
  generate_build_matrix:
    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
    # see https://github.com/pytorch/pytorch/pull/50633
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Generating build matrix
        id: set-matrix
        run: |
          # outputting for debugging purposes
          python ./.github/scripts/docker/generate_build_matrix.py
          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
          echo "::set-output name=matrix::${MATRIX}"

  ljspeech:
    needs: generate_build_matrix
    name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Free space
        shell: bash
        run: |
          ls -lh
          df -h
          rm -rf /opt/hostedtoolcache
          df -h
          echo "pwd: $PWD"
          echo "github.workspace ${{ github.workspace }}"

      - name: Run tests
        uses: addnab/docker-run-action@v3
        with:
          image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
          options: |
            --volume ${{ github.workspace }}/:/icefall
          shell: bash
          run: |
            export PYTHONPATH=/icefall:$PYTHONPATH
            cd /icefall
            git config --global --add safe.directory /icefall

            .github/scripts/ljspeech/TTS/run.sh

      - name: display files
        shell: bash
        run: |
          ls -lh

      - uses: actions/upload-artifact@v4
        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
        with:
          name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
          path: ./*.wav

      - uses: actions/upload-artifact@v4
        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
        with:
          name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
          path: ./*.wav

      - name: Release exported onnx models
        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          overwrite: true
          file: vits-icefall-*.tar.bz2
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: tts-models
2 .github/workflows/run-docker-image.yml (vendored)

@@ -14,7 +14,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+        image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
     steps:
       # refer to https://github.com/actions/checkout
       - uses: actions/checkout@v2
8 .github/workflows/style_check.yml (vendored)

@@ -49,7 +49,7 @@ jobs:

      - name: Install Python dependencies
        run: |
-          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
+          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
          # Click issue fixed in https://github.com/psf/black/pull/2966

      - name: Run flake8

@@ -67,3 +67,9 @@ jobs:
        working-directory: ${{github.workspace}}
        run: |
          black --check --diff .
+
+      - name: Run isort
+        shell: bash
+        working-directory: ${{github.workspace}}
+        run: |
+          isort --check --diff .
@@ -26,7 +26,7 @@ repos:
   # E121,E123,E126,E226,E24,E704,W503,W504

   - repo: https://github.com/pycqa/isort
-    rev: 5.11.5
+    rev: 5.10.1
     hooks:
       - id: isort
         args: ["--profile=black"]
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.7
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.3.torch1.12.1"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.3.torch1.12.1"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.3.torch1.12.1"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.3.torch1.12.1"
 ARG TORCHAUDIO_VERSION="0.12.1+cu113"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.9
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.6.torch1.13.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.6.torch1.13.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.6.torch1.13.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.6.torch1.13.0"
 ARG TORCHAUDIO_VERSION="0.13.0+cu116"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.7
-ARG K2_VERSION="1.24.4.dev20240211+cuda10.2.torch1.9.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda10.2.torch1.9.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda10.2.torch1.9.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda10.2.torch1.9.0"
 ARG TORCHAUDIO_VERSION="0.9.0"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.7.torch2.0.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.7.torch2.0.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.7.torch2.0.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.7.torch2.0.0"
 ARG TORCHAUDIO_VERSION="2.0.0+cu117"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.1.0"
 ARG TORCHAUDIO_VERSION="2.1.0+cu118"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.1.0"
 ARG TORCHAUDIO_VERSION="2.1.0+cu121"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.2.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.2.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.0"
 ARG TORCHAUDIO_VERSION="2.2.0+cu118"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.2.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.2.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.0"
 ARG TORCHAUDIO_VERSION="2.2.0+cu121"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
70 docker/torch2.2.1-cuda11.8.dockerfile (new file)

@@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel

ENV LC_ALL C.UTF-8

ARG DEBIAN_FRONTEND=noninteractive

# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu118"

LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
      curl \
      vim \
      libssl-dev \
      autoconf \
      automake \
      bzip2 \
      ca-certificates \
      ffmpeg \
      g++ \
      gfortran \
      git \
      libtool \
      make \
      patch \
      sox \
      subversion \
      unzip \
      valgrind \
      wget \
      zlib1g-dev \
      && rm -rf /var/lib/apt/lists/*

# Install dependencies
RUN pip install --no-cache-dir \
      torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
      k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
      git+https://github.com/lhotse-speech/lhotse \
      kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
      kaldi_native_io \
      kaldialign \
      kaldifst \
      kaldilm \
      sentencepiece>=0.1.96 \
      tensorboard \
      typeguard \
      dill \
      onnx \
      onnxruntime \
      onnxmltools \
      multi_quantization \
      typeguard \
      numpy \
      pytest \
      graphviz

RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
    pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

WORKDIR /workspace/icefall
70 docker/torch2.2.1-cuda12.1.dockerfile (new file)

@@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel

ENV LC_ALL C.UTF-8

ARG DEBIAN_FRONTEND=noninteractive

# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu121"

LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
      curl \
      vim \
      libssl-dev \
      autoconf \
      automake \
      bzip2 \
      ca-certificates \
      ffmpeg \
      g++ \
      gfortran \
      git \
      libtool \
      make \
      patch \
      sox \
      subversion \
      unzip \
      valgrind \
      wget \
      zlib1g-dev \
      && rm -rf /var/lib/apt/lists/*

# Install dependencies
RUN pip install --no-cache-dir \
      torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
      k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
      git+https://github.com/lhotse-speech/lhotse \
      kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
      kaldi_native_io \
      kaldialign \
      kaldifst \
      kaldilm \
      sentencepiece>=0.1.96 \
      tensorboard \
      typeguard \
      dill \
      onnx \
      onnxruntime \
      onnxmltools \
      multi_quantization \
      typeguard \
      numpy \
      pytest \
      graphviz

RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
    pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

WORKDIR /workspace/icefall
@@ -34,6 +34,8 @@ which will give you something like below:

 .. code-block:: bash

+    "torch2.2.1-cuda12.1"
+    "torch2.2.1-cuda11.8"
     "torch2.2.0-cuda12.1"
     "torch2.2.0-cuda11.8"
     "torch2.1.0-cuda12.1"
225 docs/source/recipes/Finetune/adapter/finetune_adapter.rst (new file)

@@ -0,0 +1,225 @@
Finetune from a pre-trained Zipformer model with adapters
=========================================================

This tutorial shows you how to fine-tune a pre-trained **Zipformer**
transducer model on a new dataset with adapters.
Adapters are compact and efficient modules that can be integrated into a pre-trained model
to improve the model's performance on a new domain. Adapters are injected
between different modules in the well-trained neural network. During training, only the parameters
in the adapters are updated. This achieves competitive performance
while requiring much less GPU memory than full fine-tuning. For more details about adapters,
please refer to the original `paper <https://arxiv.org/pdf/1902.00751.pdf#/>`_.
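To make the idea concrete, the following is a minimal sketch of one such
adapter (an illustration with assumed names only, not the actual code of the
``zipformer_adapter`` recipe): a bottleneck MLP whose output is added back to
its input through a residual connection.

.. code-block:: python

   import torch
   import torch.nn as nn


   class Adapter(nn.Module):
       """Illustrative adapter; the real icefall module may differ."""

       def __init__(self, d_model: int, adapter_dim: int):
           super().__init__()
           # Down-project to a small bottleneck, then back up to d_model.
           self.down = nn.Linear(d_model, adapter_dim)
           self.up = nn.Linear(adapter_dim, d_model)

       def forward(self, x: torch.Tensor) -> torch.Tensor:
           # Residual connection: only a small correction is learned on
           # top of the frozen pre-trained representation.
           return x + self.up(torch.relu(self.down(x)))

Because only ``down`` and ``up`` are trainable, the number of new parameters
per adapter is roughly ``2 * d_model * adapter_dim``, which is why a small
``--adapter-dim`` keeps fine-tuning cheap.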
.. HINT::

  We assume you have read the page :ref:`install icefall` and have set up
  the environment for ``icefall``.

.. HINT::

  We recommend using a GPU or several GPUs to run this recipe.

For illustration purposes, we fine-tune the Zipformer transducer model
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
own data for fine-tuning if you create a manifest for your new dataset.
Data preparation
----------------

Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
to prepare the fine-tuning data used in this tutorial. We only require the small subset of GigaSpeech for this tutorial.


Model preparation
-----------------

We use the Zipformer model trained on the full LibriSpeech dataset (960 hours) as the initialization. The
checkpoint of the model can be downloaded via the following commands:

.. code-block:: bash

  $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
  $ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
  $ git lfs pull --include "pretrained.pt"
  $ ln -s pretrained.pt epoch-99.pt
  $ cd ../data/lang_bpe_500
  $ git lfs pull --include bpe.model
  $ cd ../../..
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
decoding on the GigaSpeech test sets:

.. code-block:: bash

  ./zipformer/decode_gigaspeech.py \
      --epoch 99 \
      --avg 1 \
      --exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
      --use-averaged-model 0 \
      --max-duration 1000 \
      --decoding-method greedy_search

You should see the following numbers:

.. code-block::

  For dev, WER of different settings are:
  greedy_search   20.06   best for dev

  For test, WER of different settings are:
  greedy_search   19.27   best for test
Fine-tune with adapters
-----------------------

We insert 4 adapters with residual connections in each ``Zipformer2EncoderLayer``.
The original model parameters remain untouched during training and only the parameters of
the adapters are updated. The following command starts a fine-tuning experiment with adapters:

.. code-block:: bash

  $ do_finetune=1
  $ use_adapters=1
  $ adapter_dim=8

  $ ./zipformer_adapter/train.py \
      --world-size 2 \
      --num-epochs 20 \
      --start-epoch 1 \
      --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
      --use-fp16 1 \
      --base-lr 0.045 \
      --use-adapters $use_adapters --adapter-dim $adapter_dim \
      --bpe-model data/lang_bpe_500/bpe.model \
      --do-finetune $do_finetune \
      --master-port 13022 \
      --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
      --max-duration 1000
The following arguments are related to fine-tuning:

- ``--do-finetune``

  If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
  **Note that if you want to resume your fine-tuning experiment from certain epochs, you
  need to set this to False.**

- ``--use-adapters``

  Whether adapters are used during fine-tuning.

- ``--adapter-dim``

  The bottleneck dimension of the adapter module. Typically a small number.

You should notice that in the training log, the total number of trainable parameters is shown:

.. code-block::

  2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)

The trainable parameters make up only about 1.15% of the whole model (761344 out of roughly 66 million
parameters), so training is much faster and requires less memory than full fine-tuning.
Decoding
--------

After training, let's test the WERs. To test the WERs on the GigaSpeech test sets,
you can execute the following command:

.. code-block:: bash

  $ epoch=20
  $ avg=10
  $ use_adapters=1
  $ adapter_dim=8

  $ ./zipformer/decode.py \
      --epoch $epoch \
      --avg $avg \
      --use-averaged-model 1 \
      --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
      --max-duration 600 \
      --use-adapters $use_adapters \
      --adapter-dim $adapter_dim \
      --decoding-method greedy_search

You should see the following numbers:

.. code-block::

  For dev, WER of different settings are:
  greedy_search   15.44   best for dev

  For test, WER of different settings are:
  greedy_search   15.42   best for test

The WER on the test set improves from 19.27 to 15.42, demonstrating the effectiveness of the adapters.
The same model can be used to perform decoding on the LibriSpeech test sets. You can deactivate the
adapters to keep the performance of the original model unchanged:

.. code-block:: bash

  $ epoch=20
  $ avg=1
  $ use_adapters=0
  $ adapter_dim=8

  $ ./zipformer/decode.py \
      --epoch $epoch \
      --avg $avg \
      --use-averaged-model 1 \
      --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
      --max-duration 600 \
      --use-adapters $use_adapters \
      --adapter-dim $adapter_dim \
      --decoding-method greedy_search

.. code-block::

  For dev, WER of different settings are:
  greedy_search   2.23    best for test-clean

  For test, WER of different settings are:
  greedy_search   4.96    best for test-other

The numbers are the same as reported in `icefall <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md#normal-scaled-model-number-of-model-parameters-65549011-ie-6555-m>`_. Adapter-based
fine-tuning is therefore very flexible: the same model can be used for decoding on both the original
and the target domain.
Export the model
----------------

After training, the model can be exported to ``onnx`` format easily using the following command:

.. code-block:: bash

  $ use_adapters=1
  $ adapter_dim=16

  $ ./zipformer_adapter/export-onnx.py \
      --tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
      --use-averaged-model 1 \
      --epoch 20 \
      --avg 10 \
      --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
      --use-adapters $use_adapters \
      --adapter-dim $adapter_dim \
      --num-encoder-layers "2,2,3,4,3,2" \
      --downsampling-factor "1,2,4,8,4,2" \
      --feedforward-dim "512,768,1024,1536,1024,768" \
      --num-heads "4,4,4,8,4,4" \
      --encoder-dim "192,256,384,512,384,256" \
      --query-head-dim 32 \
      --value-head-dim 12 \
      --pos-head-dim 4 \
      --pos-dim 48 \
      --encoder-unmasked-dim "192,192,256,256,256,192" \
      --cnn-module-kernel "31,31,15,15,15,31" \
      --decoder-dim 512 \
      --joiner-dim 512 \
      --causal False \
      --chunk-size "16,32,64,-1" \
      --left-context-frames "64,128,256,-1"
@@ -13,3 +13,4 @@ data to improve the performance on new domains.
    :caption: Table of Contents

    from_supervised/finetune_zipformer
+   adapter/finetune_adapter
@@ -1,11 +1,11 @@
-VITS
+VITS-LJSpeech
 ===============

 This tutorial shows you how to train a VITS model
 with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.

 .. note::

   TTS related recipes require packages in ``requirements-tts.txt``.

 .. note::

@@ -13,6 +13,14 @@ with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
   The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_

+Install extra dependencies
+--------------------------
+
+.. code-block:: bash
+
+  pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
+  pip install numba espnet_tts_frontend
+
 Data preparation
 ----------------

@@ -56,7 +64,8 @@ Training
       --start-epoch 1 \
       --use-fp16 1 \
       --exp-dir vits/exp \
-      --tokens data/tokens.txt
+      --tokens data/tokens.txt \
+      --model-type high \
+      --max-duration 500

 .. note::

@@ -64,6 +73,11 @@ Training
    You can adjust the hyper-parameters to control the size of the VITS model and
    the training configurations. For more details, please run ``./vits/train.py --help``.

+.. warning::
+
+   If you want a model that runs faster on CPU, please use ``--model-type low``
+   or ``--model-type medium``.
+
 .. note::

    The training can take a long time (usually a couple of days).

@@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
 Export models
 -------------

-Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
-``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
+``vits-epoch-*.onnx``.

 .. code-block:: bash

@@ -120,4 +134,68 @@ Download pretrained models
 If you don't want to train from scratch, you can download the pretrained models
 by visiting the following link:

-- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
+- ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
+- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
+- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_
+
+Usage in sherpa-onnx
+--------------------
+
+The following describes how to test the exported ONNX model in `sherpa-onnx`_.
+
+.. hint::
+
+   `sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
+   Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.
+
+   We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
+   Please refer to `<https://k2-fsa.github.io/sherpa/onnx/>`_
+   for more documentation.
+
+Install sherpa-onnx
+^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+   pip install sherpa-onnx
+
+To check that you have installed `sherpa-onnx`_ successfully, please run:
+
+.. code-block:: bash
+
+   which sherpa-onnx-offline-tts
+   sherpa-onnx-offline-tts --help
+
+Download lexicon files
+^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+   cd /tmp
+   wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
+   tar xf espeak-ng-data.tar.bz2
+
+Run sherpa-onnx
+^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+   cd egs/ljspeech/TTS
+
+   sherpa-onnx-offline-tts \
+     --vits-model=vits/exp/vits-epoch-1000.onnx \
+     --vits-tokens=data/tokens.txt \
+     --vits-data-dir=/tmp/espeak-ng-data \
+     --num-threads=1 \
+     --output-filename=./high.wav \
+     "Ask not what your country can do for you; ask what you can do for your country."
+
+.. hint::
+
+   You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
+   as it is generating.
+
+You should get a file ``high.wav`` after running the above command.
+
+Congratulations! You have successfully trained and exported a text-to-speech
+model and run it with `sherpa-onnx`_.
@@ -1,11 +1,11 @@
-VITS
+VITS-VCTK
 ===============

 This tutorial shows you how to train a VITS model
 with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.

 .. note::

   TTS related recipes require packages in ``requirements-tts.txt``.

 .. note::
@@ -19,7 +19,9 @@ The following table lists the differences among them.
 | `transducer_stateless_modified`   | Conformer            | Embedding + Conv1d | with modified transducer from `optimized_transducer`                                    |
 | `transducer_stateless_modified-2` | Conformer            | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data                       |
 | `pruned_transducer_stateless3`    | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
-| `pruned_transducer_stateless7`    | Zipformer            | Embedding          | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 |
+| `pruned_transducer_stateless7`    | Zipformer            | Embedding          | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
+| `zipformer`                       | Upgraded Zipformer   | Embedding + Conv1d | The latest recipe with context-size set to 1 |

 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
@@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc
 | fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |

 ```bash
-./prepare.sh
+./prepare.sh

 export CUDA_VISIBLE_DEVICES="0,1"
@@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
 fi

 if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
-  log "Stage 11: Train RNN LM model"
+  log "Stage 12: Train RNN LM model"
   python ../../../icefall/rnn_lm/train.py \
     --start-epoch 0 \
     --world-size 1 \
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,

@@ -881,9 +882,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error()
             if batch_idx % params.log_interval == 0:
                 cur_lr = scheduler.get_last_lr()[0]
                 cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
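The helper imported from icefall.err in this and the following hunks is not defined in this commit. A plausible sketch, assuming it simply centralizes the inline RuntimeError it replaces (note that this first call site passes no argument, so the argument is presumably optional):

    # Hypothetical sketch of the helper in icefall/err.py; not part of this diff.
    def raise_grad_scale_is_too_small_error(cur_grad_scale: float = None) -> None:
        # Same message the call sites previously raised inline.
        raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")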
@@ -85,6 +85,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import (

@@ -878,9 +879,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)
             if batch_idx % params.log_interval == 0:
                 cur_lr = scheduler.get_last_lr()[0]
                 cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,

@@ -871,9 +872,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

             if batch_idx % params.log_interval == 0:
                 cur_lr = scheduler.get_last_lr()[0]
@@ -250,7 +250,7 @@ def get_parser():
     parser.add_argument(
         "--context-size",
         type=int,
-        default=1,
+        default=2,
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

@@ -882,9 +883,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

             if batch_idx % params.log_interval == 0:
                 cur_lr = scheduler.get_last_lr()[0]
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

@@ -881,9 +882,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

             if batch_idx % params.log_interval == 0:
                 cur_lr = scheduler.get_last_lr()[0]
@@ -86,6 +86,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import (

@@ -985,9 +986,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 save_bad_model()
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

             if batch_idx % params.log_interval == 0:
                 cur_lr = max(scheduler.get_last_lr())
@@ -83,6 +83,8 @@ from icefall.checkpoint import (
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,

@@ -570,9 +571,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 save_bad_model()
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

             if batch_idx % params.log_interval == 0:
                 cur_lr = max(scheduler.get_last_lr())
@@ -1,6 +1,6 @@
 ## Results

-### Aishell2 char-based training results
+### Aishell2 char-based training results

 #### Pruned transducer stateless 5
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
def compute_fbank_aishell2(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
"train",
|
||||
@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -111,7 +124,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -122,5 +140,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_aishell2(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi

+whisper_mel_bins=80
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+  log "Stage 30: Compute whisper fbank for aishell2"
+  if [ ! -f data/fbank/.aishell2.whisper.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.aishell2.whisper.done
+  fi
+fi
+
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Compute fbank for musan"
   if [ ! -f data/fbank/.msuan.done ]; then
@@ -3,7 +3,7 @@

 This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets).

-The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.
+The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.

 (From [Open Speech and Language Resources](https://www.openslr.org/111/))
@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool

@@ -42,10 +49,12 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell4(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests/aishell4")
     output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
+    num_jobs = min(8, os.cpu_count())

     dataset_parts = (
         "train_S",

@@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
         dataset_parts,
     )

-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():

@@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )

@@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=ChunkedLilcomHdf5Writer,
+                storage_type=LilcomChunkyWriter,
             )

             logging.info("About splitting cuts into smaller chunks")

@@ -121,7 +135,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use WhisperFbank instead of Fbank. Default: False.",
+    )
     return parser.parse_args()

@@ -132,5 +151,7 @@ if __name__ == "__main__":
     args = get_args()
     compute_fbank_aishell4(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail

 stage=-1
-stop_stage=100
+stop_stage=7
 perturb_speed=true

@@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Process aishell4"
+  log "Stage 2: Compute fbank for aishell4"
   if [ ! -f data/fbank/aishell4/.fbank.done ]; then
-    mkdir -p data/fbank/aishell4
+    mkdir -p data/fbank
     ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
-    touch data/fbank/aishell4/.fbank.done
+    touch data/fbank/.fbank.done
   fi
 fi

+whisper_mel_bins=80
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "Stage 20: Compute whisper fbank for aishell4"
+  if [ ! -f data/fbank/aishell4/.fbank.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.fbank.done
+  fi
+fi
+
@@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi

-if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Compute fbank for aishell4"
-  if [ ! -f data/fbank/.aishell4.done ]; then
-    mkdir -p data/fbank
-    ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
-    touch data/fbank/.aishell4.done
-  fi
-fi
-
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Prepare char based lang"
+  log "Stage 5: Prepare char based lang"
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir
@@ -29,7 +29,14 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
-def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_alimeeting(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests/alimeeting")
     output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
+    num_jobs = min(8, os.cpu_count())
 
     dataset_parts = (
         "train",
@@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
         "test",
     )
 
-    prefix = "alimeeting"
+    prefix = "alimeeting-far"
     suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
         dataset_parts=dataset_parts,
@@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
         dataset_parts,
     )
 
-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
@@ -121,7 +135,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use the Whisper Fbank feature extractor. Default: False.",
+    )
     return parser.parse_args()
 
@@ -132,5 +151,7 @@ if __name__ == "__main__":
 
     args = get_args()
     compute_fbank_alimeeting(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )
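For context, the new --whisper-fbank switch only picks between the two lhotse extractors shown in the hunk above. A minimal standalone sketch of that selection (the 80-bin default and the "cuda" device are taken from the diff; adjust both to your setup):

import torch
from lhotse import Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig


def make_extractor(num_mel_bins: int = 80, whisper_fbank: bool = False):
    if whisper_fbank:
        # WhisperFbank reproduces the log-Mel front end used by OpenAI
        # Whisper, which whisper-style fine-tuning recipes expect.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device=device))
    # The classic Kaldi-style fbank used by the other icefall recipes.
    return Fbank(FbankConfig(num_mel_bins=num_mel_bins))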
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail
 
 stage=-1
-stop_stage=100
+stop_stage=7
 perturb_speed=true
 
 # We assume dl_dir (download dir) contains the following
@@ -15,7 +15,7 @@ perturb_speed=true
 #
 # - $dl_dir/alimeeting
 #      This directory contains the following files downloaded from
-#       https://openslr.org/62/
+#       https://openslr.org/119/
 #
 #     - Train_Ali_far.tar.gz
 #     - Train_Ali_near.tar.gz
@@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi
 
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Process alimeeting"
-  if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
-    mkdir -p data/fbank/alimeeting
+  log "Stage 2: compute fbank for alimeeting"
+  if [ ! -f data/fbank/.fbank.done ]; then
+    mkdir -p data/fbank
     ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
+    touch data/fbank/.fbank.done
   fi
 fi
+
+whisper_mel_bins=80
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "Stage 20: compute whisper fbank for alimeeting"
+  if [ ! -f data/fbank/.fbank.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.fbank.done
+  fi
+fi
 
@@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi
 
-if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Compute fbank for alimeeting"
-  if [ ! -f data/fbank/.alimeeting.done ]; then
-    mkdir -p data/fbank
-    ./local/compute_fbank_alimeeting.py --perturb-speed True
-    touch data/fbank/.alimeeting.done
-  fi
-fi
-
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Prepare char based lang"
+  log "Stage 5: Prepare char based lang"
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir
 
@@ -12,7 +12,7 @@ use_gss=true  # Use GSS-based enhancement with MDM setting
 #
 # - $dl_dir/alimeeting
 #      This directory contains the following files downloaded from
-#       https://openslr.org/62/
+#       https://openslr.org/119/
 #
 #     - Train_Ali_far.tar.gz
 #     - Train_Ali_near.tar.gz
@@ -70,6 +70,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -851,9 +852,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
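This hunk (and the many identical ones below) replaces the inline RuntimeError with a shared helper from the new icefall.err module, so every recipe raises the same message from one place. The module path is confirmed by the import above; the body is an assumption, but presumably it is just the code it replaces, along the lines of:

def raise_grad_scale_is_too_small_error(cur_grad_scale: float) -> None:
    # Centralized so that all train.py variants fail identically when AMP's
    # loss-scale collapses (usually a sign of inf/nan gradients).
    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")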
@@ -69,6 +69,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
@@ -842,9 +843,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -75,6 +75,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1138,9 +1139,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -75,6 +75,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1129,9 +1130,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/compile_hlg.py
168
egs/commonvoice/ASR/local/compile_hlg.py
Executable file
@@ -0,0 +1,168 @@
#!/usr/bin/env python3
# Copyright 2021-2024  Xiaomi Corp.   (authors: Fangjun Kuang,
#                                               Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates HLG from

    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
    - L, the lexicon, built from lang_dir/L_disambig.pt

        Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_n_gram.fst.txt

The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lm",
        type=str,
        default="G_3_gram",
        help="""Stem name for LM used in HLG compiling.
        """,
    )
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
      lm:
        The language stem base name.

    Return:
      An FSA representing HLG.
    """
    lexicon = Lexicon(lang_dir)
    max_token_id = max(lexicon.tokens)
    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
        logging.info(f"Loading pre-compiled {lm}")
        d = torch.load(f"{lang_dir}/lm/{lm}.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info(f"Loading {lm}.fst.txt")
        with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    # LG.labels[LG.labels >= first_token_disambig_id] = 0
    # see https://github.com/k2-fsa/k2/pull/1140
    labels = LG.labels
    labels[labels >= first_token_disambig_id] = 0
    LG.labels = labels

    assert isinstance(LG.aux_labels, k2.RaggedTensor)
    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    # CAUTION: The name of the inner_labels is fixed
    # to `tokens`. If you want to change it, please
    # also change other places in icefall that are using
    # it.
    HLG = k2.compose(H, LG, inner_labels="tokens")

    logging.info("Connecting LG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting LG")
    HLG = k2.arc_sort(HLG)
    logging.info(f"HLG.shape: {HLG.shape}")

    return HLG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "HLG.pt").is_file():
        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    HLG = compile_HLG(lang_dir, args.lm)
    logging.info(f"Saving HLG.pt to {lang_dir}")
    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
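The script composes H (CTC topology), L (lexicon) and G (n-gram LM) in that order, pruning disambiguation symbols and epsilons in between. A hedged usage sketch of the function it defines — the lang dir path is hypothetical, and the directory must already contain L_disambig.pt plus lm/G_3_gram.fst.txt:

import logging
from pathlib import Path

import torch
from compile_hlg import compile_HLG  # the new file shown above

logging.basicConfig(level=logging.INFO)
lang_dir = Path("data/lang_bpe_500")  # hypothetical; produced by prepare.sh
HLG = compile_HLG(lang_dir, lm="G_3_gram")
torch.save(HLG.as_dict(), lang_dir / "HLG.pt")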
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/compile_lg.py
149
egs/commonvoice/ASR/local/compile_lg.py
Executable file
@@ -0,0 +1,149 @@
#!/usr/bin/env python3
# Copyright 2021-2024  Xiaomi Corp.   (authors: Fangjun Kuang,
#                                               Kang Wei,
#                                               Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates LG from

    - L, the lexicon, built from lang_dir/L_disambig.pt

        Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from lang_dir/lm/G_3_gram.fst.txt

The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )
    parser.add_argument(
        "--lm",
        type=str,
        default="G_3_gram",
        help="""Stem name for LM used in HLG compiling.
        """,
    )

    return parser.parse_args()


def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

    Return:
      An FSA representing LG.
    """
    lexicon = Lexicon(lang_dir)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
        logging.info(f"Loading pre-compiled {lm}")
        d = torch.load(f"{lang_dir}/lm/{lm}.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info(f"Loading {lm}.fst.txt")
        with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    # LG.labels[LG.labels >= first_token_disambig_id] = 0
    # see https://github.com/k2-fsa/k2/pull/1140
    labels = LG.labels
    labels[labels >= first_token_disambig_id] = 0
    LG.labels = labels

    assert isinstance(LG.aux_labels, k2.RaggedTensor)
    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    return LG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "LG.pt").is_file():
        logging.info(f"{lang_dir}/LG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir, args.lm)
    logging.info(f"Saving LG.pt to {lang_dir}")
    torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
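compile_lg.py differs from compile_hlg.py mainly in that it stops before composing with the CTC topology H and uses log-weight-pushing determinization. In icefall, LG.pt is typically consumed by the fast_beam_search_nbest_LG-style transducer decoding methods, while HLG.pt serves CTC-style decoding. A small sketch for inspecting the compiled graph (the lang dir is hypothetical):

import k2
import torch

# Load the graph produced by the script above.
d = torch.load("data/lang_char/LG.pt")  # hypothetical lang dir
LG = k2.Fsa.from_dict(d)
print("num states:", LG.shape[0])
print("num arcs:", LG.num_arcs)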
@@ -48,8 +48,27 @@ def normalize_text(utt: str, language: str) -> str:
     utt = re.sub("’", "'", utt)
     if language == "en":
         return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
-    if language == "fr":
+    elif language == "fr":
         return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
+    elif language == "pl":
+        return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
+    elif language == "yue":
+        return (
+            utt.replace(" ", "")
+            .replace(",", "")
+            .replace("。", " ")
+            .replace("?", "")
+            .replace("!", "")
+            .replace("?", "")
+        )
+    else:
+        raise NotImplementedError(
+            f"""
+            Text normalization not implemented for language: {language},
+            please consider implementing it in the local/preprocess_commonvoice.py
+            or raise an issue on GitHub to request it.
+            """
+        )
 
 
 def preprocess_commonvoice(
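To make the new branching concrete, here is a trimmed, self-contained copy of the English and Polish branches added above, with a worked example (the assertion shows the intended behaviour: punctuation stripped, case folded to upper):

import re


def normalize_text(utt: str, language: str) -> str:
    # Trimmed copy of the function above, for illustration only.
    utt = re.sub("’", "'", utt)
    if language == "en":
        return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
    elif language == "pl":
        return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
    raise NotImplementedError(f"language: {language}")


assert normalize_text("Hello, world!", "en") == "HELLO WORLD"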
@@ -381,9 +381,11 @@ class CommonVoiceAsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else eval(self.args.input_strategy)(),
+            input_strategy=(
+                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                if self.args.on_the_fly_feats
+                else eval(self.args.input_strategy)()
+            ),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
@@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
 import argparse
 import logging
 
-from icefall import is_module_available
+import torch
 from onnx_pretrained import OnnxModel
 
-import torch
+from icefall import is_module_available
 
 
 def get_parser():
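The onnx check scripts whose imports are regrouped above guard their optional onnxruntime dependency with icefall's is_module_available helper (the helper itself is confirmed by the import in the diff; the exact guard wording below is an assumption):

from icefall import is_module_available

if not is_module_available("onnxruntime"):
    raise ValueError("Please 'pip install onnxruntime' first.")

import onnxruntime as ort  # safe to import only after the guard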
@@ -79,6 +79,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -871,9 +872,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -31,7 +31,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
-    SingleCutSampler,
+    SimpleCutSampler,
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
@@ -232,7 +232,7 @@ class CommonVoiceAsrDataModule:
             logging.info("About to get Musan cuts")
             cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -315,8 +315,8 @@ class CommonVoiceAsrDataModule:
                 drop_last=self.args.drop_last,
             )
         else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
+            logging.info("Using SimpleCutSampler.")
+            train_sampler = SimpleCutSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
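Both replacements track upstream lhotse renames: CutMix's `prob` argument became `p`, and `SingleCutSampler` became `SimpleCutSampler` (the old spellings were deprecated upstream, so the recipes pin to the new ones). A minimal sketch of the updated API, with hypothetical manifest paths:

from lhotse import load_manifest_lazy
from lhotse.dataset import CutMix, SimpleCutSampler

cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")   # hypothetical path
musan = load_manifest_lazy("data/fbank/musan_cuts.jsonl.gz")  # hypothetical path

# Mix MUSAN noise into half of the training cuts at 10-20 dB SNR.
mix = CutMix(cuts=musan, p=0.5, snr=(10, 20), preserve_id=True)

# Single-bucket sampler capped at 200 seconds of audio per batch.
sampler = SimpleCutSampler(cuts, max_duration=200.0, shuffle=True)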
@@ -383,9 +383,11 @@ class CommonVoiceAsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else eval(self.args.input_strategy)(),
+            input_strategy=(
+                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                if self.args.on_the_fly_feats
+                else eval(self.args.input_strategy)()
+            ),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
@@ -889,9 +889,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -965,9 +966,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
@@ -888,9 +889,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
@@ -909,9 +910,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -70,9 +70,9 @@ import logging
 from pathlib import Path
 
 import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
 from scaling_converter import convert_scaled_to_non_scaled
 from tokenizer import Tokenizer
-from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
@@ -908,9 +909,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -23,6 +23,7 @@ from pathlib import Path
 
 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached
 
+from icefall.utils import str2bool
 
 # Similar text filtering and normalization procedure as in:
@@ -76,6 +76,7 @@ from beam_search import (
 )
 from gigaspeech_scoring import asr_text_post_processing
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -88,7 +88,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import GigaSpeechAsrDataModule
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -51,7 +51,7 @@ from streaming_beam_search import (
 )
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -1031,9 +1032,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
@@ -42,12 +42,10 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import GigaSpeechAsrDataModule
-from beam_search import (
-    keywords_search,
-)
+from beam_search import keywords_search
+from lhotse.cut import Cut
 from train import add_model_arguments, get_model, get_params
 
-from lhotse.cut import Cut
 from icefall import ContextGraph
 from icefall.checkpoint import (
     average_checkpoints,
@@ -76,25 +76,6 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
-from icefall import diagnostics
-from icefall.checkpoint import remove_checkpoints
-from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import (
-    save_checkpoint_with_global_batch_idx,
-    update_averaged_model,
-)
-from icefall.dist import cleanup_dist, setup_dist
-from icefall.env import get_env_info
-from icefall.hooks import register_inf_check_hooks
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    get_parameter_groups_with_lrs,
-    setup_logger,
-    str2bool,
-)
-
 from train import (
     add_model_arguments,
     add_training_arguments,
@@ -110,6 +91,25 @@ from train import (
     set_batch_count,
 )
 
+from icefall import diagnostics
+from icefall.checkpoint import remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    get_parameter_groups_with_lrs,
+    setup_logger,
+    str2bool,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
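The two hunks above only move the icefall import block below the local `train` imports (and add the new `icefall.err` helper); nothing changes semantically. The ordering presumably follows the isort grouping used across these diffs — standard library, then third-party, then modules that sit next to the script, then the icefall package itself, each group separated by a blank line. A skeleton of that convention (the local import is commented out because it only resolves inside a recipe directory):

# 1) standard library
import argparse

# 2) third-party packages
import torch

# 3) modules next to the script (recipe-local; shown commented)
# from train import add_model_arguments, get_params

# 4) the first-party icefall package, last
from icefall.utils import AttributeDict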
@@ -372,9 +372,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -1034,9 +1035,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
@@ -85,6 +85,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1169,9 +1170,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1056,9 +1057,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -232,7 +232,7 @@ class LibriHeavyAsrDataModule:
             logging.info("About to get Musan cuts")
             cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -93,6 +93,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -1036,9 +1037,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
@@ -425,9 +425,11 @@ class LibriHeavyAsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
+            input_strategy=(
+                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                if self.args.on_the_fly_feats
+                else PrecomputedFeatures()
+            ),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
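The repeated input_strategy rewrites above and below are purely mechanical: a multi-line conditional expression is wrapped in its own parentheses, which is the layout newer versions of the black formatter produce (an assumption based on the repo-wide, behaviour-preserving nature of the change). A self-contained sketch of the resulting shape:

from lhotse import Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset, PrecomputedFeatures
from lhotse.dataset.input_strategies import OnTheFlyFeatures


def make_test_dataset(on_the_fly_feats: bool, return_cuts: bool = True):
    # Same logic as the diff; the extra parentheses are formatting only and
    # let the conditional sit on its own indented lines.
    return K2SpeechRecognitionDataset(
        input_strategy=(
            OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if on_the_fly_feats
            else PrecomputedFeatures()
        ),
        return_cuts=return_cuts,
    )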
@@ -103,6 +103,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -1051,9 +1052,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
@@ -117,6 +117,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -855,9 +856,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
-        context_dim=4 * 768
-        if params.context_injection
-        else -1,  # the output dim of text encoder
+        context_dim=(
+            4 * 768 if params.context_injection else -1
+        ),  # the output dim of text encoder
         context_injection=params.context_injection,
     )
     return joiner
@@ -1398,9 +1399,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
                     save_bad_model()
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
@@ -35,8 +35,7 @@ The following table lists the differences among them.
 | `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
 | `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
 | `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
 | `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | Finetune `zipformer` with adapter |
 | `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune `zipformer` with LoRA |
-| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
 
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
@@ -24,8 +24,7 @@ To run this file, do:
 """
 
 import torch
-
-from train import get_params, get_ctc_model
+from train import get_ctc_model, get_params
 
 
 def test_model():
@@ -59,9 +59,9 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
 from emformer import Emformer
 from scaling_converter import convert_scaled_to_non_scaled
-from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -39,7 +39,7 @@ Usage of this script:
 import argparse
 import logging
 import math
-from typing import List
+from typing import List, Optional
 
 import kaldifeat
 import sentencepiece as spm
@@ -47,7 +47,6 @@ import torch
 import torchaudio
 from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
 from torch.nn.utils.rnn import pad_sequence
-from typing import Optional, List
 
 
 def get_parser():
@@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
 """
 
 import argparse
-import torch.multiprocessing as mp
-import torch
-import torch.nn as nn
 import logging
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Tuple
-
 from pathlib import Path
+from typing import List, Optional, Tuple
 
 import k2
 import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
 from asr_datamodule import AsrDataModule
 from beam_search import (
     fast_beam_search_one_best,
     greedy_search_batch,
     modified_beam_search,
 )
-from icefall.utils import AttributeDict, convert_timestamp, setup_logger
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.cut import Cut
-from lhotse.supervision import AlignmentItem
 from lhotse.serialization import SequentialJsonlWriter
+from lhotse.supervision import AlignmentItem
+
+from icefall.utils import AttributeDict, convert_timestamp, setup_logger
 
 
 def get_parser():
@@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
 import argparse
 import logging
 
+import torch
 from onnx_pretrained import OnnxModel
 
 from icefall import is_module_available
-
-import torch
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py
 
 import argparse
 import logging
 
 import sentencepiece as spm
 import torch
+from train import add_model_arguments, get_encoder_model, get_params
 
 from icefall.profiler import get_model_profile
-from train import get_encoder_model, add_model_arguments, get_params
 
 
 def get_parser():
@@ -75,8 +75,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search
 
 from icefall.utils import setup_logger, store_transcripts, write_error_stats
@@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
 import argparse
 import logging
 
-from icefall import is_module_available
+import torch
 from onnx_pretrained import OnnxModel
 
-import torch
+from icefall import is_module_available
 
 
 def get_parser():
@@ -76,8 +76,7 @@ import torch
 import torch.nn as nn
 from asr_datamodule import AsrDataModule
 from librispeech import LibriSpeech
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search
 
 from icefall.utils import setup_logger, store_transcripts, write_error_stats
 
@@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py
 
 import argparse
 import logging
+from typing import Tuple
 
 import sentencepiece as spm
 import torch
-
-from typing import Tuple
+from scaling import BasicNorm, DoubleSwish
 from torch import Tensor, nn
+from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
 
 from icefall.profiler import get_model_profile
-from scaling import BasicNorm, DoubleSwish
-from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
 
 
 def get_parser():
@@ -82,8 +82,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search
 
 from icefall.utils import setup_logger, store_transcripts, write_error_stats
@@ -20,7 +20,6 @@ from typing import List
 
 import k2
 import torch
-
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 
 # The force alignment problem can be formulated as finding
@@ -107,9 +107,6 @@ import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-
-# from asr_datamodule import LibriSpeechAsrDataModule
-from gigaspeech import GigaSpeechAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -120,6 +117,9 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+
+# from asr_datamodule import LibriSpeechAsrDataModule
+from gigaspeech import GigaSpeechAsrDataModule
 from gigaspeech_scoring import asr_text_post_processing
 from train import add_model_arguments, get_params, get_transducer_model
 
@@ -80,6 +80,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -976,9 +977,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -65,16 +65,15 @@ from typing import Dict, List
 
 import sentencepiece as spm
 import torch
-
 from train import add_model_arguments, get_params, get_transducer_model
 
-from icefall.utils import str2bool
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.utils import str2bool
 
 
 def get_parser():
@@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py
 
 import argparse
 import logging
+from typing import Tuple
 
 import sentencepiece as spm
 import torch
-
-from typing import Tuple
+from scaling import BasicNorm, DoubleSwish
 from torch import Tensor, nn
+from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
 
 from icefall.profiler import get_model_profile
-from scaling import BasicNorm, DoubleSwish
-from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
 
 
 def get_parser():
@@ -75,8 +75,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search
 
 from icefall.utils import setup_logger, store_transcripts, write_error_stats
@@ -24,7 +24,6 @@ To run this file, do:
 """
 
 import torch
-
 from scaling_converter import convert_scaled_to_non_scaled
 from train import get_params, get_transducer_model
 
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -878,9 +879,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -902,9 +903,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
-                    raise RuntimeError(
-                        f"grad_scale is too small, exiting: {cur_grad_scale}"
-                    )
+                    raise_grad_scale_is_too_small_error(cur_grad_scale)
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
Some files were not shown because too many files have changed in this diff.