mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
resolve conflict
This commit is contained in:
commit
390f01653f
1
.github/scripts/.gitignore
vendored
Normal file
1
.github/scripts/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
piper_phonemize.html
|
29
.github/scripts/generate-piper-phonemize-page.py
vendored
Executable file
29
.github/scripts/generate-piper-phonemize-page.py
vendored
Executable file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
def main():
|
||||
prefix = (
|
||||
"https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
|
||||
)
|
||||
files = [
|
||||
"piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
|
||||
"piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
|
||||
]
|
||||
with open("piper_phonemize.html", "w") as f:
|
||||
for file in files:
|
||||
url = prefix + file
|
||||
f.write(f'<a href="{url}">{file}</a><br/>\n')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
6
.github/scripts/librispeech/ASR/run.sh
vendored
6
.github/scripts/librispeech/ASR/run.sh
vendored
@ -15,9 +15,9 @@ function prepare_data() {
|
||||
# cause OOM error for CI later.
|
||||
mkdir -p download/lm
|
||||
pushd download/lm
|
||||
wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
|
||||
wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
|
||||
wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
|
||||
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
|
||||
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
|
||||
wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt
|
||||
ls -lh
|
||||
gunzip librispeech-lm-norm.txt.gz
|
||||
|
||||
|
157
.github/scripts/ljspeech/TTS/run.sh
vendored
Executable file
157
.github/scripts/ljspeech/TTS/run.sh
vendored
Executable file
@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -ex
|
||||
|
||||
python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
|
||||
python3 -m pip install espnet_tts_frontend
|
||||
python3 -m pip install numba
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
cd egs/ljspeech/TTS
|
||||
|
||||
sed -i.bak s/600/8/g ./prepare.sh
|
||||
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
|
||||
sed -i.bak s/500/5/g ./prepare.sh
|
||||
git diff
|
||||
|
||||
function prepare_data() {
|
||||
# We have created a subset of the data for testing
|
||||
#
|
||||
mkdir download
|
||||
pushd download
|
||||
wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
|
||||
tar xvf LJSpeech-1.1.tar.bz2
|
||||
popd
|
||||
|
||||
./prepare.sh
|
||||
tree .
|
||||
}
|
||||
|
||||
function train() {
|
||||
pushd ./vits
|
||||
sed -i.bak s/200/3/g ./train.py
|
||||
git diff .
|
||||
popd
|
||||
|
||||
for t in low medium high; do
|
||||
./vits/train.py \
|
||||
--exp-dir vits/exp-$t \
|
||||
--model-type $t \
|
||||
--num-epochs 1 \
|
||||
--save-every-n 1 \
|
||||
--num-buckets 2 \
|
||||
--tokens data/tokens.txt \
|
||||
--max-duration 20
|
||||
|
||||
ls -lh vits/exp-$t
|
||||
done
|
||||
}
|
||||
|
||||
function infer() {
|
||||
for t in low medium high; do
|
||||
./vits/infer.py \
|
||||
--num-buckets 2 \
|
||||
--model-type $t \
|
||||
--epoch 1 \
|
||||
--exp-dir ./vits/exp-$t \
|
||||
--tokens data/tokens.txt \
|
||||
--max-duration 20
|
||||
done
|
||||
}
|
||||
|
||||
function export_onnx() {
|
||||
for t in low medium high; do
|
||||
./vits/export-onnx.py \
|
||||
--model-type $t \
|
||||
--epoch 1 \
|
||||
--exp-dir ./vits/exp-$t \
|
||||
--tokens data/tokens.txt
|
||||
|
||||
ls -lh vits/exp-$t/
|
||||
done
|
||||
}
|
||||
|
||||
function test_medium() {
|
||||
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12
|
||||
|
||||
./vits/export-onnx.py \
|
||||
--model-type medium \
|
||||
--epoch 820 \
|
||||
--exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
|
||||
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt
|
||||
|
||||
ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp
|
||||
|
||||
./vits/test_onnx.py \
|
||||
--model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
|
||||
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
|
||||
--output-filename /icefall/test-medium.wav
|
||||
|
||||
ls -lh /icefall/test-medium.wav
|
||||
|
||||
d=/icefall/vits-icefall-en_US-ljspeech-medium
|
||||
mkdir $d
|
||||
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
|
||||
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx
|
||||
|
||||
rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12
|
||||
|
||||
pushd $d
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
|
||||
tar xf espeak-ng-data.tar.bz2
|
||||
rm espeak-ng-data.tar.bz2
|
||||
cd ..
|
||||
tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
|
||||
rm -rf vits-icefall-en_US-ljspeech-medium
|
||||
ls -lh *.tar.bz2
|
||||
popd
|
||||
}
|
||||
|
||||
function test_low() {
|
||||
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12
|
||||
|
||||
./vits/export-onnx.py \
|
||||
--model-type low \
|
||||
--epoch 1600 \
|
||||
--exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
|
||||
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt
|
||||
|
||||
ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp
|
||||
|
||||
./vits/test_onnx.py \
|
||||
--model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
|
||||
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
|
||||
--output-filename /icefall/test-low.wav
|
||||
|
||||
ls -lh /icefall/test-low.wav
|
||||
|
||||
d=/icefall/vits-icefall-en_US-ljspeech-low
|
||||
mkdir $d
|
||||
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
|
||||
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx
|
||||
|
||||
rm -rf icefall-tts-ljspeech-vits-low-2024-03-12
|
||||
|
||||
pushd $d
|
||||
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
|
||||
tar xf espeak-ng-data.tar.bz2
|
||||
rm espeak-ng-data.tar.bz2
|
||||
cd ..
|
||||
tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
|
||||
rm -rf vits-icefall-en_US-ljspeech-low
|
||||
ls -lh *.tar.bz2
|
||||
popd
|
||||
}
|
||||
|
||||
prepare_data
|
||||
train
|
||||
infer
|
||||
export_onnx
|
||||
rm -rf vits/exp-{low,medium,high}
|
||||
test_medium
|
||||
test_low
|
3
.github/workflows/build-doc.yml
vendored
3
.github/workflows/build-doc.yml
vendored
@ -56,11 +56,14 @@ jobs:
|
||||
- name: Build doc
|
||||
shell: bash
|
||||
run: |
|
||||
.github/scripts/generate-piper-phonemize-page.py
|
||||
cd docs
|
||||
python3 -m pip install -r ./requirements.txt
|
||||
make html
|
||||
touch build/html/.nojekyll
|
||||
|
||||
cp -v ../piper_phonemize.html ./build/html/
|
||||
|
||||
- name: Deploy
|
||||
uses: peaceiris/actions-gh-pages@v3
|
||||
with:
|
||||
|
2
.github/workflows/build-docker-image.yml
vendored
2
.github/workflows/build-docker-image.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
||||
image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
||||
|
||||
steps:
|
||||
# refer to https://github.com/actions/checkout
|
||||
|
102
.github/workflows/ljspeech.yml
vendored
Normal file
102
.github/workflows/ljspeech.yml
vendored
Normal file
@ -0,0 +1,102 @@
|
||||
name: ljspeech
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ljspeech-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
generate_build_matrix:
|
||||
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
|
||||
# see https://github.com/pytorch/pytorch/pull/50633
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Generating build matrix
|
||||
id: set-matrix
|
||||
run: |
|
||||
# outputting for debugging purposes
|
||||
python ./.github/scripts/docker/generate_build_matrix.py
|
||||
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
|
||||
echo "::set-output name=matrix::${MATRIX}"
|
||||
|
||||
ljspeech:
|
||||
needs: generate_build_matrix
|
||||
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Free space
|
||||
shell: bash
|
||||
run: |
|
||||
ls -lh
|
||||
df -h
|
||||
rm -rf /opt/hostedtoolcache
|
||||
df -h
|
||||
echo "pwd: $PWD"
|
||||
echo "github.workspace ${{ github.workspace }}"
|
||||
|
||||
- name: Run tests
|
||||
uses: addnab/docker-run-action@v3
|
||||
with:
|
||||
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
|
||||
options: |
|
||||
--volume ${{ github.workspace }}/:/icefall
|
||||
shell: bash
|
||||
run: |
|
||||
export PYTHONPATH=/icefall:$PYTHONPATH
|
||||
cd /icefall
|
||||
git config --global --add safe.directory /icefall
|
||||
|
||||
.github/scripts/ljspeech/TTS/run.sh
|
||||
|
||||
- name: display files
|
||||
shell: bash
|
||||
run: |
|
||||
ls -lh
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
|
||||
with:
|
||||
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
|
||||
path: ./*.wav
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
|
||||
with:
|
||||
name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
|
||||
path: ./*.wav
|
||||
|
||||
- name: Release exported onnx models
|
||||
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
overwrite: true
|
||||
file: vits-icefall-*.tar.bz2
|
||||
repo_name: k2-fsa/sherpa-onnx
|
||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||
tag: tts-models
|
||||
|
2
.github/workflows/run-docker-image.yml
vendored
2
.github/workflows/run-docker-image.yml
vendored
@ -14,7 +14,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
||||
image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
||||
steps:
|
||||
# refer to https://github.com/actions/checkout
|
||||
- uses: actions/checkout@v2
|
||||
|
8
.github/workflows/style_check.yml
vendored
8
.github/workflows/style_check.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
|
||||
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
|
||||
# Click issue fixed in https://github.com/psf/black/pull/2966
|
||||
|
||||
- name: Run flake8
|
||||
@ -67,3 +67,9 @@ jobs:
|
||||
working-directory: ${{github.workspace}}
|
||||
run: |
|
||||
black --check --diff .
|
||||
|
||||
- name: Run isort
|
||||
shell: bash
|
||||
working-directory: ${{github.workspace}}
|
||||
run: |
|
||||
isort --check --diff .
|
||||
|
@ -26,7 +26,7 @@ repos:
|
||||
# E121,E123,E126,E226,E24,E704,W503,W504
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.11.5
|
||||
rev: 5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
args: ["--profile=black"]
|
||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.7
|
||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.3.torch1.12.1"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.3.torch1.12.1"
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda11.3.torch1.12.1"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.3.torch1.12.1"
|
||||
ARG TORCHAUDIO_VERSION="0.12.1+cu113"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.9
|
||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.6.torch1.13.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.6.torch1.13.0"
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda11.6.torch1.13.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.6.torch1.13.0"
|
||||
ARG TORCHAUDIO_VERSION="0.13.0+cu116"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.7
|
||||
ARG K2_VERSION="1.24.4.dev20240211+cuda10.2.torch1.9.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda10.2.torch1.9.0"
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda10.2.torch1.9.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda10.2.torch1.9.0"
|
||||
ARG TORCHAUDIO_VERSION="0.9.0"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.10
|
||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.7.torch2.0.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.7.torch2.0.0"
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda11.7.torch2.0.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.7.torch2.0.0"
|
||||
ARG TORCHAUDIO_VERSION="2.0.0+cu117"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.10
|
||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.1.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.1.0"
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.1.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.1.0"
|
||||
ARG TORCHAUDIO_VERSION="2.1.0+cu118"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.10
|
||||
ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.1.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.1.0"
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.1.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.1.0"
|
||||
ARG TORCHAUDIO_VERSION="2.1.0+cu121"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.10
|
||||
ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.2.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.2.0"
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.0"
|
||||
ARG TORCHAUDIO_VERSION="2.2.0+cu118"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
|
@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.10
|
||||
ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.2.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.2.0"
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.0"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.0"
|
||||
ARG TORCHAUDIO_VERSION="2.2.0+cu121"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
|
70
docker/torch2.2.1-cuda11.8.dockerfile
Normal file
70
docker/torch2.2.1-cuda11.8.dockerfile
Normal file
@ -0,0 +1,70 @@
|
||||
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel
|
||||
|
||||
ENV LC_ALL C.UTF-8
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.10
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.1"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.1"
|
||||
ARG TORCHAUDIO_VERSION="2.2.1+cu118"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
LABEL k2_version=${K2_VERSION}
|
||||
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
vim \
|
||||
libssl-dev \
|
||||
autoconf \
|
||||
automake \
|
||||
bzip2 \
|
||||
ca-certificates \
|
||||
ffmpeg \
|
||||
g++ \
|
||||
gfortran \
|
||||
git \
|
||||
libtool \
|
||||
make \
|
||||
patch \
|
||||
sox \
|
||||
subversion \
|
||||
unzip \
|
||||
valgrind \
|
||||
wget \
|
||||
zlib1g-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install dependencies
|
||||
RUN pip install --no-cache-dir \
|
||||
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||
git+https://github.com/lhotse-speech/lhotse \
|
||||
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||
kaldi_native_io \
|
||||
kaldialign \
|
||||
kaldifst \
|
||||
kaldilm \
|
||||
sentencepiece>=0.1.96 \
|
||||
tensorboard \
|
||||
typeguard \
|
||||
dill \
|
||||
onnx \
|
||||
onnxruntime \
|
||||
onnxmltools \
|
||||
multi_quantization \
|
||||
typeguard \
|
||||
numpy \
|
||||
pytest \
|
||||
graphviz
|
||||
|
||||
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||
cd /workspace/icefall && \
|
||||
pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||
|
||||
WORKDIR /workspace/icefall
|
70
docker/torch2.2.1-cuda12.1.dockerfile
Normal file
70
docker/torch2.2.1-cuda12.1.dockerfile
Normal file
@ -0,0 +1,70 @@
|
||||
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel
|
||||
|
||||
ENV LC_ALL C.UTF-8
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# python 3.10
|
||||
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.1"
|
||||
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.1"
|
||||
ARG TORCHAUDIO_VERSION="2.2.1+cu121"
|
||||
|
||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||
LABEL k2_version=${K2_VERSION}
|
||||
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
vim \
|
||||
libssl-dev \
|
||||
autoconf \
|
||||
automake \
|
||||
bzip2 \
|
||||
ca-certificates \
|
||||
ffmpeg \
|
||||
g++ \
|
||||
gfortran \
|
||||
git \
|
||||
libtool \
|
||||
make \
|
||||
patch \
|
||||
sox \
|
||||
subversion \
|
||||
unzip \
|
||||
valgrind \
|
||||
wget \
|
||||
zlib1g-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install dependencies
|
||||
RUN pip install --no-cache-dir \
|
||||
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||
git+https://github.com/lhotse-speech/lhotse \
|
||||
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||
kaldi_native_io \
|
||||
kaldialign \
|
||||
kaldifst \
|
||||
kaldilm \
|
||||
sentencepiece>=0.1.96 \
|
||||
tensorboard \
|
||||
typeguard \
|
||||
dill \
|
||||
onnx \
|
||||
onnxruntime \
|
||||
onnxmltools \
|
||||
multi_quantization \
|
||||
typeguard \
|
||||
numpy \
|
||||
pytest \
|
||||
graphviz
|
||||
|
||||
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||
cd /workspace/icefall && \
|
||||
pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||
|
||||
WORKDIR /workspace/icefall
|
@ -34,6 +34,8 @@ which will give you something like below:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
"torch2.2.1-cuda12.1"
|
||||
"torch2.2.1-cuda11.8"
|
||||
"torch2.2.0-cuda12.1"
|
||||
"torch2.2.0-cuda11.8"
|
||||
"torch2.1.0-cuda12.1"
|
||||
|
225
docs/source/recipes/Finetune/adapter/finetune_adapter.rst
Normal file
225
docs/source/recipes/Finetune/adapter/finetune_adapter.rst
Normal file
@ -0,0 +1,225 @@
|
||||
Finetune from a pre-trained Zipformer model with adapters
|
||||
=========================================================
|
||||
|
||||
This tutorial shows you how to fine-tune a pre-trained **Zipformer**
|
||||
transducer model on a new dataset with adapters.
|
||||
Adapters are compact and efficient module that can be integrated into a pre-trained model
|
||||
to improve the model's performance on a new domain. Adapters are injected
|
||||
between different modules in the well-trained neural network. During training, only the parameters
|
||||
in the adapters will be updated. It achieves competitive performance
|
||||
while requiring much less GPU memory than full fine-tuning. For more details about adapters,
|
||||
please refer to the original `paper <https://arxiv.org/pdf/1902.00751.pdf#/>`_ for more details.
|
||||
|
||||
.. HINT::
|
||||
|
||||
We assume you have read the page :ref:`install icefall` and have setup
|
||||
the environment for ``icefall``.
|
||||
|
||||
.. HINT::
|
||||
|
||||
We recommend you to use a GPU or several GPUs to run this recipe
|
||||
|
||||
For illustration purpose, we fine-tune the Zipformer transducer model
|
||||
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
|
||||
own data for fine-tuning if you create a manifest for your new dataset.
|
||||
|
||||
Data preparation
|
||||
----------------
|
||||
|
||||
Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
|
||||
to prepare the fine-tune data used in this tutorial. We only require the small subset in GigaSpeech for this tutorial.
|
||||
|
||||
|
||||
Model preparation
|
||||
-----------------
|
||||
|
||||
We are using the Zipformer model trained on full LibriSpeech (960 hours) as the intialization. The
|
||||
checkpoint of the model can be downloaded via the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
$ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
|
||||
$ git lfs pull --include "pretrained.pt"
|
||||
$ ln -s pretrained.pt epoch-99.pt
|
||||
$ cd ../data/lang_bpe_500
|
||||
$ git lfs pull --include bpe.model
|
||||
$ cd ../../..
|
||||
|
||||
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
|
||||
decoding on the GigaSpeech test sets:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./zipformer/decode_gigaspeech.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
|
||||
--use-averaged-model 0 \
|
||||
--max-duration 1000 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
You should see the following numbers:
|
||||
|
||||
.. code-block::
|
||||
|
||||
For dev, WER of different settings are:
|
||||
greedy_search 20.06 best for dev
|
||||
|
||||
For test, WER of different settings are:
|
||||
greedy_search 19.27 best for test
|
||||
|
||||
|
||||
Fine-tune with adapter
|
||||
----------------------
|
||||
|
||||
We insert 4 adapters with residual connection in each ``Zipformer2EncoderLayer``.
|
||||
The original model parameters remain untouched during training and only the parameters of
|
||||
the adapters are updated. The following command starts a fine-tuning experiment with adapters:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ do_finetune=1
|
||||
$ use_adapters=1
|
||||
$ adapter_dim=8
|
||||
|
||||
$ ./zipformer_adapter/train.py \
|
||||
--world-size 2 \
|
||||
--num-epochs 20 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||
--use-fp16 1 \
|
||||
--base-lr 0.045 \
|
||||
--use-adapters $use_adapters --adapter-dim $adapter_dim \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--do-finetune $do_finetune \
|
||||
--master-port 13022 \
|
||||
--finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
|
||||
--max-duration 1000
|
||||
|
||||
The following arguments are related to fine-tuning:
|
||||
|
||||
- ``--do-finetune``
|
||||
If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
|
||||
**Note that if you want to resume your fine-tuning experiment from certain epochs, you
|
||||
need to set this to False.**
|
||||
|
||||
- ``use-adapters``
|
||||
If adapters are used during fine-tuning.
|
||||
|
||||
- ``--adapter-dim``
|
||||
The bottleneck dimension of the adapter module. Typically a small number.
|
||||
|
||||
You should notice that in the training log, the total number of trainale parameters is shown:
|
||||
|
||||
.. code-block::
|
||||
|
||||
2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)
|
||||
|
||||
The trainable parameters only makes up 1.15% of the entire model parameters, so the training will be much faster
|
||||
and requires less memory than full fine-tuning.
|
||||
|
||||
|
||||
Decoding
|
||||
--------
|
||||
|
||||
After training, let's test the WERs. To test the WERs on the GigaSpeech set,
|
||||
you can execute the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ epoch=20
|
||||
$ avg=10
|
||||
$ use_adapters=1
|
||||
$ adapter_dim=8
|
||||
|
||||
% ./zipformer/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--use-averaged-model 1 \
|
||||
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||
--max-duration 600 \
|
||||
--use-adapters $use_adapters \
|
||||
--adapter-dim $adapter_dim \
|
||||
--decoding-method greedy_search
|
||||
|
||||
You should see the following numbers:
|
||||
|
||||
.. code-block::
|
||||
|
||||
For dev, WER of different settings are:
|
||||
greedy_search 15.44 best for dev
|
||||
|
||||
For test, WER of different settings are:
|
||||
greedy_search 15.42 best for test
|
||||
|
||||
|
||||
The WER on test set is improved from 19.27 to 15.42, demonstrating the effectiveness of adapters.
|
||||
|
||||
The same model can be used to perform decoding on LibriSpeech test sets. You can deactivate the adapters
|
||||
to keep the same performance of the original model:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ epoch=20
|
||||
$ avg=1
|
||||
$ use_adapters=0
|
||||
$ adapter_dim=8
|
||||
|
||||
% ./zipformer/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--use-averaged-model 1 \
|
||||
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||
--max-duration 600 \
|
||||
--use-adapters $use_adapters \
|
||||
--adapter-dim $adapter_dim \
|
||||
--decoding-method greedy_search
|
||||
|
||||
|
||||
.. code-block::
|
||||
|
||||
For dev, WER of different settings are:
|
||||
greedy_search 2.23 best for test-clean
|
||||
|
||||
For test, WER of different settings are:
|
||||
greedy_search 4.96 best for test-other
|
||||
|
||||
The numbers are the same as reported in `icefall <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md#normal-scaled-model-number-of-model-parameters-65549011-ie-6555-m>`_. So adapter-based
|
||||
fine-tuning is also very flexible as the same model can be used for decoding on the original and target domain.
|
||||
|
||||
|
||||
Export the model
|
||||
----------------
|
||||
|
||||
After training, the model can be exported to ``onnx`` format easily using the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ use_adapters=1
|
||||
$ adapter_dim=16
|
||||
|
||||
$ ./zipformer_adapter/export-onnx.py \
|
||||
--tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
|
||||
--use-averaged-model 1 \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
|
||||
--use-adapters $use_adapters \
|
||||
--adapter-dim $adapter_dim \
|
||||
--num-encoder-layers "2,2,3,4,3,2" \
|
||||
--downsampling-factor "1,2,4,8,4,2" \
|
||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||
--num-heads "4,4,4,8,4,4" \
|
||||
--encoder-dim "192,256,384,512,384,256" \
|
||||
--query-head-dim 32 \
|
||||
--value-head-dim 12 \
|
||||
--pos-head-dim 4 \
|
||||
--pos-dim 48 \
|
||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal False \
|
||||
--chunk-size "16,32,64,-1" \
|
||||
--left-context-frames "64,128,256,-1"
|
@ -13,3 +13,4 @@ data to improve the performance on new domains.
|
||||
:caption: Table of Contents
|
||||
|
||||
from_supervised/finetune_zipformer
|
||||
adapter/finetune_adapter
|
||||
|
@ -1,4 +1,4 @@
|
||||
VITS
|
||||
VITS-LJSpeech
|
||||
===============
|
||||
|
||||
This tutorial shows you how to train an VITS model
|
||||
@ -13,6 +13,14 @@ with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
|
||||
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
|
||||
|
||||
|
||||
Install extra dependencies
|
||||
--------------------------
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
|
||||
pip install numba espnet_tts_frontend
|
||||
|
||||
Data preparation
|
||||
----------------
|
||||
|
||||
@ -56,7 +64,8 @@ Training
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir vits/exp \
|
||||
--tokens data/tokens.txt
|
||||
--tokens data/tokens.txt \
|
||||
--model-type high \
|
||||
--max-duration 500
|
||||
|
||||
.. note::
|
||||
@ -64,6 +73,11 @@ Training
|
||||
You can adjust the hyper-parameters to control the size of the VITS model and
|
||||
the training configurations. For more details, please run ``./vits/train.py --help``.
|
||||
|
||||
.. warning::
|
||||
|
||||
If you want a model that runs faster on CPU, please use ``--model-type low``
|
||||
or ``--model-type medium``.
|
||||
|
||||
.. note::
|
||||
|
||||
The training can take a long time (usually a couple of days).
|
||||
@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
|
||||
Export models
|
||||
-------------
|
||||
|
||||
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
|
||||
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
|
||||
Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
|
||||
``vits-epoch-*.onnx``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@ -120,4 +134,68 @@ Download pretrained models
|
||||
If you don't want to train from scratch, you can download the pretrained models
|
||||
by visiting the following link:
|
||||
|
||||
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
|
||||
- ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
|
||||
- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
|
||||
- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_
|
||||
|
||||
Usage in sherpa-onnx
|
||||
--------------------
|
||||
|
||||
The following describes how to test the exported ONNX model in `sherpa-onnx`_.
|
||||
|
||||
.. hint::
|
||||
|
||||
`sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
|
||||
Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.
|
||||
|
||||
We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
|
||||
Please refer to `<https://k2-fsa.github.io/sherpa/onnx/>`_
|
||||
for more documentation.
|
||||
|
||||
Install sherpa-onnx
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install sherpa-onnx
|
||||
|
||||
To check that you have installed `sherpa-onnx`_ successfully, please run:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
which sherpa-onnx-offline-tts
|
||||
sherpa-onnx-offline-tts --help
|
||||
|
||||
Download lexicon files
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd /tmp
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
|
||||
tar xf espeak-ng-data.tar.bz2
|
||||
|
||||
Run sherpa-onnx
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd egs/ljspeech/TTS
|
||||
|
||||
sherpa-onnx-offline-tts \
|
||||
--vits-model=vits/exp/vits-epoch-1000.onnx \
|
||||
--vits-tokens=data/tokens.txt \
|
||||
--vits-data-dir=/tmp/espeak-ng-data \
|
||||
--num-threads=1 \
|
||||
--output-filename=./high.wav \
|
||||
"Ask not what your country can do for you; ask what you can do for your country."
|
||||
|
||||
.. hint::
|
||||
|
||||
You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
|
||||
as it is generating.
|
||||
|
||||
You should get a file ``high.wav`` after running the above command.
|
||||
|
||||
Congratulations! You have successfully trained and exported a text-to-speech
|
||||
model and run it with `sherpa-onnx`_.
|
||||
|
@ -1,4 +1,4 @@
|
||||
VITS
|
||||
VITS-VCTK
|
||||
===============
|
||||
|
||||
This tutorial shows you how to train an VITS model
|
||||
|
@ -19,7 +19,9 @@ The following table lists the differences among them.
|
||||
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
|
||||
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
|
||||
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
|
||||
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 |
|
||||
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
|
||||
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
|
||||
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
|
@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "Stage 11: Train RNN LM model"
|
||||
log "Stage 12: Train RNN LM model"
|
||||
python ../../../icefall/rnn_lm/train.py \
|
||||
--start-epoch 0 \
|
||||
--world-size 1 \
|
||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -881,9 +882,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error()
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||
|
@ -85,6 +85,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
@ -878,9 +879,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -871,9 +872,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -250,7 +250,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=1,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
@ -882,9 +883,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
@ -881,9 +882,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -86,6 +86,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
@ -985,9 +986,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -83,6 +83,7 @@ from icefall.checkpoint import (
|
||||
update_averaged_model,
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -570,9 +571,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
def compute_fbank_aishell2(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
"train",
|
||||
@ -68,7 +77,11 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -111,7 +124,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -122,5 +140,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_aishell2(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
whisper_mel_bins=80
|
||||
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
|
||||
log "Stage 30: Compute whisper fbank for aishell2"
|
||||
if [ ! -f data/fbank/.aishell2.whisper.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
touch data/fbank/.aishell2.whisper.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Compute fbank for musan"
|
||||
if [ ! -f data/fbank/.msuan.done ]; then
|
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
def compute_fbank_aishell4(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/aishell4")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
"train_S",
|
||||
@ -70,6 +79,11 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=ChunkedLilcomHdf5Writer,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
|
||||
logging.info("About splitting cuts into smaller chunks")
|
||||
@ -121,7 +135,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -132,5 +151,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_aishell4(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
set -eou pipefail
|
||||
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
stop_stage=7
|
||||
perturb_speed=true
|
||||
|
||||
|
||||
@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Process aishell4"
|
||||
log "Stage 2: Compute fbank for aishell4"
|
||||
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
|
||||
mkdir -p data/fbank/aishell4
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
|
||||
touch data/fbank/aishell4/.fbank.done
|
||||
touch data/fbank/.fbank.done
|
||||
fi
|
||||
fi
|
||||
|
||||
whisper_mel_bins=80
|
||||
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||
log "Stage 20: Compute whisper fbank for aishell4"
|
||||
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
touch data/fbank/.fbank.done
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compute fbank for aishell4"
|
||||
if [ ! -f data/fbank/.aishell4.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
|
||||
touch data/fbank/.aishell4.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare char based lang"
|
||||
log "Stage 5: Prepare char based lang"
|
||||
lang_char_dir=data/lang_char
|
||||
mkdir -p $lang_char_dir
|
||||
|
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,10 +49,12 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||
def compute_fbank_alimeeting(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/alimeeting")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
"train",
|
||||
@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
"test",
|
||||
)
|
||||
|
||||
prefix = "alimeeting"
|
||||
prefix = "alimeeting-far"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
@ -70,6 +79,11 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -121,7 +135,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use the Whisper Fbank feature extractor. Default: False.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -132,5 +151,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_alimeeting(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
set -eou pipefail
|
||||
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
stop_stage=7
|
||||
perturb_speed=true
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
@ -15,7 +15,7 @@ perturb_speed=true
|
||||
#
|
||||
# - $dl_dir/alimeeting
|
||||
# This directory contains the following files downloaded from
|
||||
# https://openslr.org/62/
|
||||
# https://openslr.org/119/
|
||||
#
|
||||
# - Train_Ali_far.tar.gz
|
||||
# - Train_Ali_near.tar.gz
|
||||
@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Process alimeeting"
|
||||
if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
|
||||
mkdir -p data/fbank/alimeeting
|
||||
log "Stage 2: compute fbank for alimeeting"
|
||||
if [ ! -f data/fbank/.fbank.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
|
||||
touch data/fbank/.fbank.done
|
||||
fi
|
||||
fi
|
||||
|
||||
whisper_mel_bins=80
|
||||
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||
log "Stage 20: compute whisper fbank for alimeeting"
|
||||
if [ ! -f data/fbank/.fbank.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
touch data/fbank/.fbank.done
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Compute fbank for alimeeting"
|
||||
if [ ! -f data/fbank/.alimeeting.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_alimeeting.py --perturb-speed True
|
||||
touch data/fbank/.alimeeting.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare char based lang"
|
||||
log "Stage 5: Prepare char based lang"
|
||||
lang_char_dir=data/lang_char
|
||||
mkdir -p $lang_char_dir
|
||||
|
||||
|
@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting
|
||||
#
|
||||
# - $dl_dir/alimeeting
|
||||
# This directory contains the following files downloaded from
|
||||
# https://openslr.org/62/
|
||||
# https://openslr.org/119/
|
||||
#
|
||||
# - Train_Ali_far.tar.gz
|
||||
# - Train_Ali_near.tar.gz
|
||||
|
@ -70,6 +70,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
@ -851,9 +852,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -69,6 +69,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -842,9 +843,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -75,6 +75,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -1138,9 +1139,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -75,6 +75,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -1129,9 +1130,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -1 +0,0 @@
|
||||
../../../librispeech/ASR/local/compile_hlg.py
|
168
egs/commonvoice/ASR/local/compile_hlg.py
Executable file
168
egs/commonvoice/ASR/local/compile_hlg.py
Executable file
@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengrui Jin,)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input lang_dir and generates HLG from
|
||||
|
||||
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
|
||||
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||
|
||||
Caution: We use a lexicon that contains disambiguation symbols
|
||||
|
||||
- G, the LM, built from data/lm/G_n_gram.fst.txt
|
||||
|
||||
The generated HLG is saved in $lang_dir/HLG.pt
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
type=str,
|
||||
default="G_3_gram",
|
||||
help="""Stem name for LM used in HLG compiling.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
|
||||
lm:
|
||||
The language stem base name.
|
||||
|
||||
Return:
|
||||
An FSA representing HLG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
|
||||
|
||||
first_token_disambig_id = lexicon.token_table["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
L = k2.arc_sort(L)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
logging.info("Intersecting L and G")
|
||||
LG = k2.compose(L, G)
|
||||
logging.info(f"LG shape: {LG.shape}")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
LG = k2.connect(LG)
|
||||
logging.info(f"LG shape after k2.connect: {LG.shape}")
|
||||
|
||||
logging.info(type(LG.aux_labels))
|
||||
logging.info("Determinizing LG")
|
||||
|
||||
LG = k2.determinize(LG)
|
||||
logging.info(type(LG.aux_labels))
|
||||
|
||||
logging.info("Connecting LG after k2.determinize")
|
||||
LG = k2.connect(LG)
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
# LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# see https://github.com/k2-fsa/k2/pull/1140
|
||||
labels = LG.labels
|
||||
labels[labels >= first_token_disambig_id] = 0
|
||||
LG.labels = labels
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
||||
LG = k2.remove_epsilon(LG)
|
||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||
|
||||
LG = k2.connect(LG)
|
||||
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
LG = k2.arc_sort(LG)
|
||||
|
||||
logging.info("Composing H and LG")
|
||||
# CAUTION: The name of the inner_labels is fixed
|
||||
# to `tokens`. If you want to change it, please
|
||||
# also change other places in icefall that are using
|
||||
# it.
|
||||
HLG = k2.compose(H, LG, inner_labels="tokens")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
HLG = k2.connect(HLG)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
HLG = k2.arc_sort(HLG)
|
||||
logging.info(f"HLG.shape: {HLG.shape}")
|
||||
|
||||
return HLG
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
if (lang_dir / "HLG.pt").is_file():
|
||||
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
HLG = compile_HLG(lang_dir, args.lm)
|
||||
logging.info(f"Saving HLG.pt to {lang_dir}")
|
||||
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
@ -1 +0,0 @@
|
||||
../../../librispeech/ASR/local/compile_lg.py
|
149
egs/commonvoice/ASR/local/compile_lg.py
Executable file
149
egs/commonvoice/ASR/local/compile_lg.py
Executable file
@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Kang Wei,
|
||||
# Zengrui Jin,)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input lang_dir and generates LG from
|
||||
|
||||
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||
|
||||
Caution: We use a lexicon that contains disambiguation symbols
|
||||
|
||||
- G, the LM, built from lang_dir/lm/G_3_gram.fst.txt
|
||||
|
||||
The generated LG is saved in $lang_dir/LG.pt
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
type=str,
|
||||
default="G_3_gram",
|
||||
help="""Stem name for LM used in HLG compiling.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
|
||||
|
||||
Return:
|
||||
An FSA representing LG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
|
||||
logging.info(f"Loading pre-compiled {lm}")
|
||||
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
|
||||
G = k2.Fsa.from_dict(d)
|
||||
else:
|
||||
logging.info(f"Loading {lm}.fst.txt")
|
||||
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
|
||||
|
||||
first_token_disambig_id = lexicon.token_table["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
L = k2.arc_sort(L)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
logging.info("Intersecting L and G")
|
||||
LG = k2.compose(L, G)
|
||||
logging.info(f"LG shape: {LG.shape}")
|
||||
|
||||
logging.info("Connecting LG")
|
||||
LG = k2.connect(LG)
|
||||
logging.info(f"LG shape after k2.connect: {LG.shape}")
|
||||
|
||||
logging.info(type(LG.aux_labels))
|
||||
logging.info("Determinizing LG")
|
||||
|
||||
LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
|
||||
logging.info(type(LG.aux_labels))
|
||||
|
||||
logging.info("Connecting LG after k2.determinize")
|
||||
LG = k2.connect(LG)
|
||||
|
||||
logging.info("Removing disambiguation symbols on LG")
|
||||
|
||||
# LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||
# see https://github.com/k2-fsa/k2/pull/1140
|
||||
labels = LG.labels
|
||||
labels[labels >= first_token_disambig_id] = 0
|
||||
LG.labels = labels
|
||||
|
||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||
|
||||
LG = k2.remove_epsilon(LG)
|
||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||
|
||||
LG = k2.connect(LG)
|
||||
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
|
||||
|
||||
logging.info("Arc sorting LG")
|
||||
LG = k2.arc_sort(LG)
|
||||
|
||||
return LG
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
|
||||
if (lang_dir / "LG.pt").is_file():
|
||||
logging.info(f"{lang_dir}/LG.pt already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {lang_dir}")
|
||||
|
||||
LG = compile_LG(lang_dir, args.lm)
|
||||
logging.info(f"Saving LG.pt to {lang_dir}")
|
||||
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
@ -48,8 +48,27 @@ def normalize_text(utt: str, language: str) -> str:
|
||||
utt = re.sub("’", "'", utt)
|
||||
if language == "en":
|
||||
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
|
||||
if language == "fr":
|
||||
elif language == "fr":
|
||||
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
|
||||
elif language == "pl":
|
||||
return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
|
||||
elif language == "yue":
|
||||
return (
|
||||
utt.replace(" ", "")
|
||||
.replace(",", "")
|
||||
.replace("。", " ")
|
||||
.replace("?", "")
|
||||
.replace("!", "")
|
||||
.replace("?", "")
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"""
|
||||
Text normalization not implemented for language: {language},
|
||||
please consider implementing it in the local/preprocess_commonvoice.py
|
||||
or raise an issue on GitHub to request it.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def preprocess_commonvoice(
|
||||
|
@ -381,9 +381,11 @@ class CommonVoiceAsrDataModule:
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
input_strategy=(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
else eval(self.args.input_strategy)()
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
|
@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from icefall import is_module_available
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
|
||||
import torch
|
||||
from icefall import is_module_available
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -79,6 +79,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -871,9 +872,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
@ -232,7 +232,7 @@ class CommonVoiceAsrDataModule:
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
@ -315,8 +315,8 @@ class CommonVoiceAsrDataModule:
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
@ -383,9 +383,11 @@ class CommonVoiceAsrDataModule:
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
input_strategy=(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
else eval(self.args.input_strategy)()
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
|
@ -889,9 +889,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise RuntimeError(f", exiting: {cur_grad_scale}")
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -965,9 +966,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -78,6 +78,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -888,9 +889,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -909,9 +910,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -70,9 +70,9 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from tokenizer import Tokenizer
|
||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
@ -908,9 +909,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -23,6 +23,7 @@ from pathlib import Path
|
||||
|
||||
from lhotse import CutSet, SupervisionSegment
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Similar text filtering and normalization procedure as in:
|
||||
|
@ -76,6 +76,7 @@ from beam_search import (
|
||||
)
|
||||
from gigaspeech_scoring import asr_text_post_processing
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
|
@ -88,7 +88,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import GigaSpeechAsrDataModule
|
||||
from train import add_model_arguments, get_params, get_model
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_model
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1031,9 +1032,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -42,12 +42,10 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import GigaSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
keywords_search,
|
||||
)
|
||||
from beam_search import keywords_search
|
||||
from lhotse.cut import Cut
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from lhotse.cut import Cut
|
||||
from icefall import ContextGraph
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -76,25 +76,6 @@ from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import (
|
||||
save_checkpoint_with_global_batch_idx,
|
||||
update_averaged_model,
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
add_training_arguments,
|
||||
@ -110,6 +91,25 @@ from train import (
|
||||
set_batch_count,
|
||||
)
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import (
|
||||
save_checkpoint_with_global_batch_idx,
|
||||
update_averaged_model,
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
get_parameter_groups_with_lrs,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
|
||||
@ -372,9 +372,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -89,6 +89,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1034,9 +1035,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -85,6 +85,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -1169,9 +1170,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -1056,9 +1057,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -232,7 +232,7 @@ class LibriHeavyAsrDataModule:
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
@ -93,6 +93,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1036,9 +1037,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -425,9 +425,11 @@ class LibriHeavyAsrDataModule:
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
input_strategy=(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
else PrecomputedFeatures()
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
|
@ -103,6 +103,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -1051,9 +1052,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -117,6 +117,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -855,9 +856,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
context_dim=4 * 768
|
||||
if params.context_injection
|
||||
else -1, # the output dim of text encoder
|
||||
context_dim=(
|
||||
4 * 768 if params.context_injection else -1
|
||||
), # the output dim of text encoder
|
||||
context_injection=params.context_injection,
|
||||
)
|
||||
return joiner
|
||||
@ -1398,9 +1399,7 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = max(scheduler.get_last_lr())
|
||||
|
@ -35,8 +35,7 @@ The following table lists the differences among them.
|
||||
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
|
||||
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
|
||||
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
|
||||
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | Finetune `zipformer` with adapter |
|
||||
| `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune `zipformer` with LoRA |
|
||||
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
|
@ -24,8 +24,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from train import get_params, get_ctc_model
|
||||
from train import get_ctc_model, get_params
|
||||
|
||||
|
||||
def test_model():
|
||||
|
@ -59,9 +59,9 @@ import onnx
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||
from emformer import Emformer
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -39,7 +39,7 @@ Usage of this script:
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
@ -47,7 +47,6 @@ import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch.multiprocessing as mp
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AsrDataModule
|
||||
from beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.supervision import AlignmentItem
|
||||
from lhotse.serialization import SequentialJsonlWriter
|
||||
from lhotse.supervision import AlignmentItem
|
||||
|
||||
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_encoder_model, get_params
|
||||
|
||||
from icefall.profiler import get_model_profile
|
||||
from train import get_encoder_model, add_model_arguments, get_params
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from icefall import is_module_available
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
|
||||
import torch
|
||||
from icefall import is_module_available
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -76,8 +76,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AsrDataModule
|
||||
from librispeech import LibriSpeech
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from scaling import BasicNorm, DoubleSwish
|
||||
from torch import Tensor, nn
|
||||
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||
|
||||
from icefall.profiler import get_model_profile
|
||||
from scaling import BasicNorm, DoubleSwish
|
||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -82,8 +82,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
@ -20,7 +20,6 @@ from typing import List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
|
||||
# The force alignment problem can be formulated as finding
|
||||
|
@ -107,9 +107,6 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from gigaspeech import GigaSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
@ -120,6 +117,9 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
|
||||
# from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from gigaspeech import GigaSpeechAsrDataModule
|
||||
from gigaspeech_scoring import asr_text_post_processing
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
@ -80,6 +80,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -976,9 +977,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -65,16 +65,15 @@ from typing import Dict, List
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from scaling import BasicNorm, DoubleSwish
|
||||
from torch import Tensor, nn
|
||||
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||
|
||||
from icefall.profiler import get_model_profile
|
||||
from scaling import BasicNorm, DoubleSwish
|
||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
|
||||
from onnx_pretrained import greedy_search, OnnxModel
|
||||
from onnx_pretrained import OnnxModel, greedy_search
|
||||
|
||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||
|
||||
|
@ -24,7 +24,6 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -878,9 +879,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -81,6 +81,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -902,9 +903,7 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 0.01:
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
|
@ -118,8 +118,8 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -18,10 +18,7 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
ScaledConv1d,
|
||||
)
|
||||
from scaling import ActivationBalancer, ScaledConv1d
|
||||
|
||||
|
||||
class LConv(nn.Module):
|
||||
|
@ -52,7 +52,7 @@ import onnxruntime as ort
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
|
||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user