resolve conflict

This commit is contained in:
marcoyang 2024-03-14 18:22:36 +08:00
commit 390f01653f
285 changed files with 21770 additions and 625 deletions

1
.github/scripts/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
piper_phonemize.html

View File

@ -0,0 +1,29 @@
#!/usr/bin/env python3
def main():
prefix = (
"https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
)
files = [
"piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
]
with open("piper_phonemize.html", "w") as f:
for file in files:
url = prefix + file
f.write(f'<a href="{url}">{file}</a><br/>\n')
if __name__ == "__main__":
main()

View File

@ -15,9 +15,9 @@ function prepare_data() {
# cause OOM error for CI later. # cause OOM error for CI later.
mkdir -p download/lm mkdir -p download/lm
pushd download/lm pushd download/lm
wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt
ls -lh ls -lh
gunzip librispeech-lm-norm.txt.gz gunzip librispeech-lm-norm.txt.gz

157
.github/scripts/ljspeech/TTS/run.sh vendored Executable file
View File

@ -0,0 +1,157 @@
#!/usr/bin/env bash
set -ex
python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
python3 -m pip install espnet_tts_frontend
python3 -m pip install numba
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/ljspeech/TTS
sed -i.bak s/600/8/g ./prepare.sh
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
sed -i.bak s/500/5/g ./prepare.sh
git diff
function prepare_data() {
# We have created a subset of the data for testing
#
mkdir download
pushd download
wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
tar xvf LJSpeech-1.1.tar.bz2
popd
./prepare.sh
tree .
}
function train() {
pushd ./vits
sed -i.bak s/200/3/g ./train.py
git diff .
popd
for t in low medium high; do
./vits/train.py \
--exp-dir vits/exp-$t \
--model-type $t \
--num-epochs 1 \
--save-every-n 1 \
--num-buckets 2 \
--tokens data/tokens.txt \
--max-duration 20
ls -lh vits/exp-$t
done
}
function infer() {
for t in low medium high; do
./vits/infer.py \
--num-buckets 2 \
--model-type $t \
--epoch 1 \
--exp-dir ./vits/exp-$t \
--tokens data/tokens.txt \
--max-duration 20
done
}
function export_onnx() {
for t in low medium high; do
./vits/export-onnx.py \
--model-type $t \
--epoch 1 \
--exp-dir ./vits/exp-$t \
--tokens data/tokens.txt
ls -lh vits/exp-$t/
done
}
function test_medium() {
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12
./vits/export-onnx.py \
--model-type medium \
--epoch 820 \
--exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt
ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp
./vits/test_onnx.py \
--model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
--output-filename /icefall/test-medium.wav
ls -lh /icefall/test-medium.wav
d=/icefall/vits-icefall-en_US-ljspeech-medium
mkdir $d
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx
rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12
pushd $d
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
rm espeak-ng-data.tar.bz2
cd ..
tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
rm -rf vits-icefall-en_US-ljspeech-medium
ls -lh *.tar.bz2
popd
}
function test_low() {
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12
./vits/export-onnx.py \
--model-type low \
--epoch 1600 \
--exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt
ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp
./vits/test_onnx.py \
--model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
--output-filename /icefall/test-low.wav
ls -lh /icefall/test-low.wav
d=/icefall/vits-icefall-en_US-ljspeech-low
mkdir $d
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx
rm -rf icefall-tts-ljspeech-vits-low-2024-03-12
pushd $d
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
rm espeak-ng-data.tar.bz2
cd ..
tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
rm -rf vits-icefall-en_US-ljspeech-low
ls -lh *.tar.bz2
popd
}
prepare_data
train
infer
export_onnx
rm -rf vits/exp-{low,medium,high}
test_medium
test_low

View File

@ -56,11 +56,14 @@ jobs:
- name: Build doc - name: Build doc
shell: bash shell: bash
run: | run: |
.github/scripts/generate-piper-phonemize-page.py
cd docs cd docs
python3 -m pip install -r ./requirements.txt python3 -m pip install -r ./requirements.txt
make html make html
touch build/html/.nojekyll touch build/html/.nojekyll
cp -v ../piper_phonemize.html ./build/html/
- name: Deploy - name: Deploy
uses: peaceiris/actions-gh-pages@v3 uses: peaceiris/actions-gh-pages@v3
with: with:

View File

@ -16,7 +16,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps: steps:
# refer to https://github.com/actions/checkout # refer to https://github.com/actions/checkout

102
.github/workflows/ljspeech.yml vendored Normal file
View File

@ -0,0 +1,102 @@
name: ljspeech
on:
push:
branches:
- master
pull_request:
branches:
- master
workflow_dispatch:
concurrency:
group: ljspeech-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
echo "::set-output name=matrix::${MATRIX}"
ljspeech:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Free space
shell: bash
run: |
ls -lh
df -h
rm -rf /opt/hostedtoolcache
df -h
echo "pwd: $PWD"
echo "github.workspace ${{ github.workspace }}"
- name: Run tests
uses: addnab/docker-run-action@v3
with:
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
run: |
export PYTHONPATH=/icefall:$PYTHONPATH
cd /icefall
git config --global --add safe.directory /icefall
.github/scripts/ljspeech/TTS/run.sh
- name: display files
shell: bash
run: |
ls -lh
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
with:
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
path: ./*.wav
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
with:
name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
path: ./*.wav
- name: Release exported onnx models
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
overwrite: true
file: vits-icefall-*.tar.bz2
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: tts-models

View File

@ -14,7 +14,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps: steps:
# refer to https://github.com/actions/checkout # refer to https://github.com/actions/checkout
- uses: actions/checkout@v2 - uses: actions/checkout@v2

View File

@ -49,7 +49,7 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
# Click issue fixed in https://github.com/psf/black/pull/2966 # Click issue fixed in https://github.com/psf/black/pull/2966
- name: Run flake8 - name: Run flake8
@ -67,3 +67,9 @@ jobs:
working-directory: ${{github.workspace}} working-directory: ${{github.workspace}}
run: | run: |
black --check --diff . black --check --diff .
- name: Run isort
shell: bash
working-directory: ${{github.workspace}}
run: |
isort --check --diff .

View File

@ -26,7 +26,7 @@ repos:
# E121,E123,E126,E226,E24,E704,W503,W504 # E121,E123,E126,E226,E24,E704,W503,W504
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 5.11.5 rev: 5.10.1
hooks: hooks:
- id: isort - id: isort
args: ["--profile=black"] args: ["--profile=black"]

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
# python 3.7 # python 3.7
ARG K2_VERSION="1.24.4.dev20240211+cuda11.3.torch1.12.1" ARG K2_VERSION="1.24.4.dev20240223+cuda11.3.torch1.12.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.3.torch1.12.1" ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.3.torch1.12.1"
ARG TORCHAUDIO_VERSION="0.12.1+cu113" ARG TORCHAUDIO_VERSION="0.12.1+cu113"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>" LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
# python 3.9 # python 3.9
ARG K2_VERSION="1.24.4.dev20240211+cuda11.6.torch1.13.0" ARG K2_VERSION="1.24.4.dev20240223+cuda11.6.torch1.13.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.6.torch1.13.0" ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.6.torch1.13.0"
ARG TORCHAUDIO_VERSION="0.13.0+cu116" ARG TORCHAUDIO_VERSION="0.13.0+cu116"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>" LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
# python 3.7 # python 3.7
ARG K2_VERSION="1.24.4.dev20240211+cuda10.2.torch1.9.0" ARG K2_VERSION="1.24.4.dev20240223+cuda10.2.torch1.9.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda10.2.torch1.9.0" ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda10.2.torch1.9.0"
ARG TORCHAUDIO_VERSION="0.9.0" ARG TORCHAUDIO_VERSION="0.9.0"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>" LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
# python 3.10 # python 3.10
ARG K2_VERSION="1.24.4.dev20240211+cuda11.7.torch2.0.0" ARG K2_VERSION="1.24.4.dev20240223+cuda11.7.torch2.0.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.7.torch2.0.0" ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.7.torch2.0.0"
ARG TORCHAUDIO_VERSION="2.0.0+cu117" ARG TORCHAUDIO_VERSION="2.0.0+cu117"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>" LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
# python 3.10 # python 3.10
ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.1.0" ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.1.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.1.0" ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.1.0"
ARG TORCHAUDIO_VERSION="2.1.0+cu118" ARG TORCHAUDIO_VERSION="2.1.0+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>" LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
# python 3.10 # python 3.10
ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.1.0" ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.1.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.1.0" ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.1.0"
ARG TORCHAUDIO_VERSION="2.1.0+cu121" ARG TORCHAUDIO_VERSION="2.1.0+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>" LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
# python 3.10 # python 3.10
ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.2.0" ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.2.0" ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.0"
ARG TORCHAUDIO_VERSION="2.2.0+cu118" ARG TORCHAUDIO_VERSION="2.2.0+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>" LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
# python 3.10 # python 3.10
ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.2.0" ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.2.0" ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.0"
ARG TORCHAUDIO_VERSION="2.2.0+cu121" ARG TORCHAUDIO_VERSION="2.2.0+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>" LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@ -34,6 +34,8 @@ which will give you something like below:
.. code-block:: bash .. code-block:: bash
"torch2.2.1-cuda12.1"
"torch2.2.1-cuda11.8"
"torch2.2.0-cuda12.1" "torch2.2.0-cuda12.1"
"torch2.2.0-cuda11.8" "torch2.2.0-cuda11.8"
"torch2.1.0-cuda12.1" "torch2.1.0-cuda12.1"

View File

@ -0,0 +1,225 @@
Finetune from a pre-trained Zipformer model with adapters
=========================================================
This tutorial shows you how to fine-tune a pre-trained **Zipformer**
transducer model on a new dataset with adapters.
Adapters are compact and efficient module that can be integrated into a pre-trained model
to improve the model's performance on a new domain. Adapters are injected
between different modules in the well-trained neural network. During training, only the parameters
in the adapters will be updated. It achieves competitive performance
while requiring much less GPU memory than full fine-tuning. For more details about adapters,
please refer to the original `paper <https://arxiv.org/pdf/1902.00751.pdf#/>`_ for more details.
.. HINT::
We assume you have read the page :ref:`install icefall` and have setup
the environment for ``icefall``.
.. HINT::
We recommend you to use a GPU or several GPUs to run this recipe
For illustration purpose, we fine-tune the Zipformer transducer model
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
own data for fine-tuning if you create a manifest for your new dataset.
Data preparation
----------------
Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
to prepare the fine-tune data used in this tutorial. We only require the small subset in GigaSpeech for this tutorial.
Model preparation
-----------------
We are using the Zipformer model trained on full LibriSpeech (960 hours) as the intialization. The
checkpoint of the model can be downloaded via the following command:
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
$ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ cd ../data/lang_bpe_500
$ git lfs pull --include bpe.model
$ cd ../../..
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
decoding on the GigaSpeech test sets:
.. code-block:: bash
./zipformer/decode_gigaspeech.py \
--epoch 99 \
--avg 1 \
--exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
--use-averaged-model 0 \
--max-duration 1000 \
--decoding-method greedy_search
You should see the following numbers:
.. code-block::
For dev, WER of different settings are:
greedy_search 20.06 best for dev
For test, WER of different settings are:
greedy_search 19.27 best for test
Fine-tune with adapter
----------------------
We insert 4 adapters with residual connection in each ``Zipformer2EncoderLayer``.
The original model parameters remain untouched during training and only the parameters of
the adapters are updated. The following command starts a fine-tuning experiment with adapters:
.. code-block:: bash
$ do_finetune=1
$ use_adapters=1
$ adapter_dim=8
$ ./zipformer_adapter/train.py \
--world-size 2 \
--num-epochs 20 \
--start-epoch 1 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--use-fp16 1 \
--base-lr 0.045 \
--use-adapters $use_adapters --adapter-dim $adapter_dim \
--bpe-model data/lang_bpe_500/bpe.model \
--do-finetune $do_finetune \
--master-port 13022 \
--finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
--max-duration 1000
The following arguments are related to fine-tuning:
- ``--do-finetune``
If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
**Note that if you want to resume your fine-tuning experiment from certain epochs, you
need to set this to False.**
- ``use-adapters``
If adapters are used during fine-tuning.
- ``--adapter-dim``
The bottleneck dimension of the adapter module. Typically a small number.
You should notice that in the training log, the total number of trainale parameters is shown:
.. code-block::
2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)
The trainable parameters only makes up 1.15% of the entire model parameters, so the training will be much faster
and requires less memory than full fine-tuning.
Decoding
--------
After training, let's test the WERs. To test the WERs on the GigaSpeech set,
you can execute the following command:
.. code-block:: bash
$ epoch=20
$ avg=10
$ use_adapters=1
$ adapter_dim=8
% ./zipformer/decode.py \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--max-duration 600 \
--use-adapters $use_adapters \
--adapter-dim $adapter_dim \
--decoding-method greedy_search
You should see the following numbers:
.. code-block::
For dev, WER of different settings are:
greedy_search 15.44 best for dev
For test, WER of different settings are:
greedy_search 15.42 best for test
The WER on test set is improved from 19.27 to 15.42, demonstrating the effectiveness of adapters.
The same model can be used to perform decoding on LibriSpeech test sets. You can deactivate the adapters
to keep the same performance of the original model:
.. code-block:: bash
$ epoch=20
$ avg=1
$ use_adapters=0
$ adapter_dim=8
% ./zipformer/decode.py \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--max-duration 600 \
--use-adapters $use_adapters \
--adapter-dim $adapter_dim \
--decoding-method greedy_search
.. code-block::
For dev, WER of different settings are:
greedy_search 2.23 best for test-clean
For test, WER of different settings are:
greedy_search 4.96 best for test-other
The numbers are the same as reported in `icefall <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md#normal-scaled-model-number-of-model-parameters-65549011-ie-6555-m>`_. So adapter-based
fine-tuning is also very flexible as the same model can be used for decoding on the original and target domain.
Export the model
----------------
After training, the model can be exported to ``onnx`` format easily using the following command:
.. code-block:: bash
$ use_adapters=1
$ adapter_dim=16
$ ./zipformer_adapter/export-onnx.py \
--tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
--use-averaged-model 1 \
--epoch 20 \
--avg 10 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--use-adapters $use_adapters \
--adapter-dim $adapter_dim \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"

View File

@ -13,3 +13,4 @@ data to improve the performance on new domains.
:caption: Table of Contents :caption: Table of Contents
from_supervised/finetune_zipformer from_supervised/finetune_zipformer
adapter/finetune_adapter

View File

@ -1,11 +1,11 @@
VITS VITS-LJSpeech
=============== ===============
This tutorial shows you how to train an VITS model This tutorial shows you how to train an VITS model
with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset. with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
.. note:: .. note::
TTS related recipes require packages in ``requirements-tts.txt``. TTS related recipes require packages in ``requirements-tts.txt``.
.. note:: .. note::
@ -13,6 +13,14 @@ with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_ The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
Install extra dependencies
--------------------------
.. code-block:: bash
pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
pip install numba espnet_tts_frontend
Data preparation Data preparation
---------------- ----------------
@ -56,7 +64,8 @@ Training
--start-epoch 1 \ --start-epoch 1 \
--use-fp16 1 \ --use-fp16 1 \
--exp-dir vits/exp \ --exp-dir vits/exp \
--tokens data/tokens.txt --tokens data/tokens.txt \
--model-type high \
--max-duration 500 --max-duration 500
.. note:: .. note::
@ -64,6 +73,11 @@ Training
You can adjust the hyper-parameters to control the size of the VITS model and You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``. the training configurations. For more details, please run ``./vits/train.py --help``.
.. warning::
If you want a model that runs faster on CPU, please use ``--model-type low``
or ``--model-type medium``.
.. note:: .. note::
The training can take a long time (usually a couple of days). The training can take a long time (usually a couple of days).
@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
Export models Export models
------------- -------------
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. ``vits-epoch-*.onnx``.
.. code-block:: bash .. code-block:: bash
@ -120,4 +134,68 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models If you don't want to train from scratch, you can download the pretrained models
by visiting the following link: by visiting the following link:
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_ - ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_
Usage in sherpa-onnx
--------------------
The following describes how to test the exported ONNX model in `sherpa-onnx`_.
.. hint::
`sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.
We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
Please refer to `<https://k2-fsa.github.io/sherpa/onnx/>`_
for more documentation.
Install sherpa-onnx
^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
pip install sherpa-onnx
To check that you have installed `sherpa-onnx`_ successfully, please run:
.. code-block:: bash
which sherpa-onnx-offline-tts
sherpa-onnx-offline-tts --help
Download lexicon files
^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
cd /tmp
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
Run sherpa-onnx
^^^^^^^^^^^^^^^
.. code-block:: bash
cd egs/ljspeech/TTS
sherpa-onnx-offline-tts \
--vits-model=vits/exp/vits-epoch-1000.onnx \
--vits-tokens=data/tokens.txt \
--vits-data-dir=/tmp/espeak-ng-data \
--num-threads=1 \
--output-filename=./high.wav \
"Ask not what your country can do for you; ask what you can do for your country."
.. hint::
You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
as it is generating.
You should get a file ``high.wav`` after running the above command.
Congratulations! You have successfully trained and exported a text-to-speech
model and run it with `sherpa-onnx`_.

View File

@ -1,11 +1,11 @@
VITS VITS-VCTK
=============== ===============
This tutorial shows you how to train an VITS model This tutorial shows you how to train an VITS model
with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset. with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.
.. note:: .. note::
TTS related recipes require packages in ``requirements-tts.txt``. TTS related recipes require packages in ``requirements-tts.txt``.
.. note:: .. note::

View File

@ -19,7 +19,9 @@ The following table lists the differences among them.
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` | | `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data | | `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data| | `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 | | `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
The decoder in `transducer_stateless` is modified from the paper The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc
| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 | | fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |
```bash ```bash
./prepare.sh ./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1" export CUDA_VISIBLE_DEVICES="0,1"

View File

@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
fi fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 11: Train RNN LM model" log "Stage 12: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \ python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \ --start-epoch 0 \
--world-size 1 \ --world-size 1 \

View File

@ -89,6 +89,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error()
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

View File

@ -85,6 +85,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
@ -878,9 +879,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

View File

@ -78,6 +78,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -871,9 +872,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -250,7 +250,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
default=1, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
) )
parser.add_argument( parser.add_argument(

View File

@ -78,6 +78,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -882,9 +883,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -78,6 +78,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -86,6 +86,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
@ -985,9 +986,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -83,6 +83,7 @@ from icefall.checkpoint import (
update_averaged_model, update_averaged_model,
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -570,9 +571,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -1,6 +1,6 @@
## Results ## Results
### Aishell2 char-based training results ### Aishell2 char-based training results
#### Pruned transducer stateless 5 #### Pruned transducer stateless 5

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False): def compute_fbank_aishell2(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests") src_dir = Path("data/manifests")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(8, os.cpu_count())
dataset_parts = ( dataset_parts = (
"train", "train",
@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
list(manifests.keys()), list(manifests.keys()),
dataset_parts, dataset_parts,
) )
if whisper_fbank:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items(): for partition, m in manifests.items():
@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"], supervisions=m["supervisions"],
) )
if "train" in partition and perturb_speed: if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb") logging.info("Doing speed perturb")
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
) )
@ -111,7 +124,12 @@ def get_args():
default=False, default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
) )
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args() return parser.parse_args()
@ -122,5 +140,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_aishell2( compute_fbank_aishell2(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi fi
fi fi
whisper_mel_bins=80
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
log "Stage 30: Compute whisper fbank for aishell2"
if [ ! -f data/fbank/.aishell2.whisper.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.aishell2.whisper.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan" log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.msuan.done ]; then if [ ! -f data/fbank/.msuan.done ]; then

View File

@ -3,7 +3,7 @@
This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets). This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets).
The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks. The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.
(From [Open Speech and Language Resources](https://www.openslr.org/111/)) (From [Open Speech and Language Resources](https://www.openslr.org/111/))

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False): def compute_fbank_aishell4(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/aishell4") src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(8, os.cpu_count())
dataset_parts = ( dataset_parts = (
"train_S", "train_S",
@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
dataset_parts, dataset_parts,
) )
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items(): for partition, m in manifests.items():
@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"], supervisions=m["supervisions"],
) )
if "train" in partition and perturb_speed: if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb") logging.info("Doing speed perturb")
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
) )
@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
# when an executor is specified, make more partitions # when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80, num_jobs=num_jobs if ex is None else 80,
executor=ex, executor=ex,
storage_type=ChunkedLilcomHdf5Writer, storage_type=LilcomChunkyWriter,
) )
logging.info("About splitting cuts into smaller chunks") logging.info("About splitting cuts into smaller chunks")
@ -121,7 +135,12 @@ def get_args():
default=False, default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
) )
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args() return parser.parse_args()
@ -132,5 +151,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_aishell4( compute_fbank_aishell4(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail set -eou pipefail
stage=-1 stage=-1
stop_stage=100 stop_stage=7
perturb_speed=true perturb_speed=true
@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process aishell4" log "Stage 2: Compute fbank for aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then if [ ! -f data/fbank/aishell4/.fbank.done ]; then
mkdir -p data/fbank/aishell4 mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
touch data/fbank/aishell4/.fbank.done touch data/fbank/.fbank.done
fi
fi
whisper_mel_bins=80
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "Stage 20: Compute whisper fbank for aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.fbank.done
fi fi
fi fi
@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for aishell4" log "Stage 5: Prepare char based lang"
if [ ! -f data/fbank/.aishell4.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
touch data/fbank/.aishell4.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
lang_char_dir=data/lang_char lang_char_dir=data/lang_char
mkdir -p $lang_char_dir mkdir -p $lang_char_dir

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False): def compute_fbank_alimeeting(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/alimeeting") src_dir = Path("data/manifests/alimeeting")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(8, os.cpu_count())
dataset_parts = ( dataset_parts = (
"train", "train",
@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
"test", "test",
) )
prefix = "alimeeting" prefix = "alimeeting-far"
suffix = "jsonl.gz" suffix = "jsonl.gz"
manifests = read_manifests_if_cached( manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, dataset_parts=dataset_parts,
@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
dataset_parts, dataset_parts,
) )
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items(): for partition, m in manifests.items():
@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
supervisions=m["supervisions"], supervisions=m["supervisions"],
) )
if "train" in partition and perturb_speed: if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb") logging.info("Doing speed perturb")
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
) )
@ -121,7 +135,12 @@ def get_args():
default=False, default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
) )
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use the Whisper Fbank feature extractor. Default: False.",
)
return parser.parse_args() return parser.parse_args()
@ -132,5 +151,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_alimeeting( compute_fbank_alimeeting(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail set -eou pipefail
stage=-1 stage=-1
stop_stage=100 stop_stage=7
perturb_speed=true perturb_speed=true
# We assume dl_dir (download dir) contains the following # We assume dl_dir (download dir) contains the following
@ -15,7 +15,7 @@ perturb_speed=true
# #
# - $dl_dir/alimeeting # - $dl_dir/alimeeting
# This directory contains the following files downloaded from # This directory contains the following files downloaded from
# https://openslr.org/62/ # https://openslr.org/119/
# #
# - Train_Ali_far.tar.gz # - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz # - Train_Ali_near.tar.gz
@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process alimeeting" log "Stage 2: compute fbank for alimeeting"
if [ ! -f data/fbank/alimeeting/.fbank.done ]; then if [ ! -f data/fbank/.fbank.done ]; then
mkdir -p data/fbank/alimeeting mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
touch data/fbank/.fbank.done
fi
fi
whisper_mel_bins=80
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "Stage 20: compute whisper fbank for alimeeting"
if [ ! -f data/fbank/.fbank.done ]; then
mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.fbank.done
fi fi
fi fi
@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for alimeeting" log "Stage 5: Prepare char based lang"
if [ ! -f data/fbank/.alimeeting.done ]; then
mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed True
touch data/fbank/.alimeeting.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
lang_char_dir=data/lang_char lang_char_dir=data/lang_char
mkdir -p $lang_char_dir mkdir -p $lang_char_dir

View File

@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting
# #
# - $dl_dir/alimeeting # - $dl_dir/alimeeting
# This directory contains the following files downloaded from # This directory contains the following files downloaded from
# https://openslr.org/62/ # https://openslr.org/119/
# #
# - Train_Ali_far.tar.gz # - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz # - Train_Ali_near.tar.gz

View File

@ -70,6 +70,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -851,9 +852,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -69,6 +69,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -842,9 +843,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -75,6 +75,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1138,9 +1139,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -75,6 +75,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1129,9 +1130,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -1 +0,0 @@
../../../librispeech/ASR/local/compile_hlg.py

View File

@ -0,0 +1,168 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
lm:
The language stem base name.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/local/compile_lg.py

View File

@ -0,0 +1,149 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Kang Wei,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates LG from
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from lang_dir/lm/G_3_gram.fst.txt
The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
return parser.parse_args()
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
Return:
An FSA representing LG.
"""
lexicon = Lexicon(lang_dir)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
return LG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "LG.pt").is_file():
logging.info(f"{lang_dir}/LG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
LG = compile_LG(lang_dir, args.lm)
logging.info(f"Saving LG.pt to {lang_dir}")
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -48,8 +48,27 @@ def normalize_text(utt: str, language: str) -> str:
utt = re.sub("", "'", utt) utt = re.sub("", "'", utt)
if language == "en": if language == "en":
return re.sub(r"[^a-zA-Z\s]", "", utt).upper() return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
if language == "fr": elif language == "fr":
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper() return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
elif language == "pl":
return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
elif language == "yue":
return (
utt.replace(" ", "")
.replace("", "")
.replace("", " ")
.replace("", "")
.replace("", "")
.replace("?", "")
)
else:
raise NotImplementedError(
f"""
Text normalization not implemented for language: {language},
please consider implementing it in the local/preprocess_commonvoice.py
or raise an issue on GitHub to request it.
"""
)
def preprocess_commonvoice( def preprocess_commonvoice(

View File

@ -381,9 +381,11 @@ class CommonVoiceAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset") logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset( test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) input_strategy=(
if self.args.on_the_fly_feats OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
else eval(self.args.input_strategy)(), if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = DynamicBucketingSampler( sampler = DynamicBucketingSampler(

View File

@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
import argparse import argparse
import logging import logging
from icefall import is_module_available import torch
from onnx_pretrained import OnnxModel from onnx_pretrained import OnnxModel
import torch from icefall import is_module_available
def get_parser(): def get_parser():

View File

@ -79,6 +79,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -871,9 +872,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@ -232,7 +232,7 @@ class CommonVoiceAsrDataModule:
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
@ -315,8 +315,8 @@ class CommonVoiceAsrDataModule:
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
@ -383,9 +383,11 @@ class CommonVoiceAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset") logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset( test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) input_strategy=(
if self.args.on_the_fly_feats OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
else eval(self.args.input_strategy)(), if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = DynamicBucketingSampler( sampler = DynamicBucketingSampler(

View File

@ -889,9 +889,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise RuntimeError(f", exiting: {cur_grad_scale}")
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -81,6 +81,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -965,9 +966,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -78,6 +78,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -888,9 +889,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -81,6 +81,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -909,9 +910,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -70,9 +70,9 @@ import logging
from pathlib import Path from pathlib import Path
import torch import torch
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from tokenizer import Tokenizer from tokenizer import Tokenizer
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,

View File

@ -81,6 +81,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -908,9 +909,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -23,6 +23,7 @@ from pathlib import Path
from lhotse import CutSet, SupervisionSegment from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool from icefall.utils import str2bool
# Similar text filtering and normalization procedure as in: # Similar text filtering and normalization procedure as in:

View File

@ -76,6 +76,7 @@ from beam_search import (
) )
from gigaspeech_scoring import asr_text_post_processing from gigaspeech_scoring import asr_text_post_processing
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,

View File

@ -88,7 +88,7 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule from asr_datamodule import GigaSpeechAsrDataModule
from train import add_model_arguments, get_params, get_model from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,

View File

@ -51,7 +51,7 @@ from streaming_beam_search import (
) )
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_model from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,

View File

@ -89,6 +89,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -1031,9 +1032,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -42,12 +42,10 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule from asr_datamodule import GigaSpeechAsrDataModule
from beam_search import ( from beam_search import keywords_search
keywords_search, from lhotse.cut import Cut
)
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
from lhotse.cut import Cut
from icefall import ContextGraph from icefall import ContextGraph
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,

View File

@ -76,25 +76,6 @@ from torch import Tensor
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
)
from train import ( from train import (
add_model_arguments, add_model_arguments,
add_training_arguments, add_training_arguments,
@ -110,6 +91,25 @@ from train import (
set_batch_count, set_batch_count,
) )
from icefall import diagnostics
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -372,9 +372,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -89,6 +89,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -1034,9 +1035,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -85,6 +85,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1169,9 +1170,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -81,6 +81,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1056,9 +1057,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -232,7 +232,7 @@ class LibriHeavyAsrDataModule:
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")

View File

@ -93,6 +93,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -1036,9 +1037,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -425,9 +425,11 @@ class LibriHeavyAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset") logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset( test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) input_strategy=(
if self.args.on_the_fly_feats OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
else PrecomputedFeatures(), if self.args.on_the_fly_feats
else PrecomputedFeatures()
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = DynamicBucketingSampler( sampler = DynamicBucketingSampler(

View File

@ -103,6 +103,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -1051,9 +1052,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -117,6 +117,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -855,9 +856,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
decoder_dim=params.decoder_dim, decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim, joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
context_dim=4 * 768 context_dim=(
if params.context_injection 4 * 768 if params.context_injection else -1
else -1, # the output dim of text encoder ), # the output dim of text encoder
context_injection=params.context_injection, context_injection=params.context_injection,
) )
return joiner return joiner
@ -1398,9 +1399,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -35,8 +35,7 @@ The following table lists the differences among them.
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) | | `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty | | `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe | | `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | Finetune `zipformer` with adapter | | `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
| `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune `zipformer` with LoRA |
The decoder in `transducer_stateless` is modified from the paper The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -24,8 +24,7 @@ To run this file, do:
""" """
import torch import torch
from train import get_ctc_model, get_params
from train import get_params, get_ctc_model
def test_model(): def test_model():

View File

@ -59,9 +59,9 @@ import onnx
import torch import torch
import torch.nn as nn import torch.nn as nn
from decoder import Decoder from decoder import Decoder
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from emformer import Emformer from emformer import Emformer
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,

View File

@ -39,7 +39,7 @@ Usage of this script:
import argparse import argparse
import logging import logging
import math import math
from typing import List from typing import List, Optional
import kaldifeat import kaldifeat
import sentencepiece as spm import sentencepiece as spm
@ -47,7 +47,6 @@ import torch
import torchaudio import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from typing import Optional, List
def get_parser(): def get_parser():

View File

@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
""" """
import argparse import argparse
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import logging import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Tuple
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple
import k2 import k2
import sentencepiece as spm import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from beam_search import ( from beam_search import (
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
) )
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
from lhotse import CutSet, load_manifest_lazy from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.supervision import AlignmentItem
from lhotse.serialization import SequentialJsonlWriter from lhotse.serialization import SequentialJsonlWriter
from lhotse.supervision import AlignmentItem
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
def get_parser(): def get_parser():

View File

@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
import argparse import argparse
import logging import logging
import torch
from onnx_pretrained import OnnxModel from onnx_pretrained import OnnxModel
from icefall import is_module_available from icefall import is_module_available
import torch
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(

View File

@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py
import argparse import argparse
import logging import logging
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from train import add_model_arguments, get_encoder_model, get_params
from icefall.profiler import get_model_profile from icefall.profiler import get_model_profile
from train import get_encoder_model, add_model_arguments, get_params
def get_parser(): def get_parser():

View File

@ -75,8 +75,7 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import OnnxModel, greedy_search
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
import argparse import argparse
import logging import logging
from icefall import is_module_available import torch
from onnx_pretrained import OnnxModel from onnx_pretrained import OnnxModel
import torch from icefall import is_module_available
def get_parser(): def get_parser():

View File

@ -76,8 +76,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from librispeech import LibriSpeech from librispeech import LibriSpeech
from onnx_pretrained import OnnxModel, greedy_search
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py
import argparse import argparse
import logging import logging
from typing import Tuple
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from scaling import BasicNorm, DoubleSwish
from typing import Tuple
from torch import Tensor, nn from torch import Tensor, nn
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
from icefall.profiler import get_model_profile from icefall.profiler import get_model_profile
from scaling import BasicNorm, DoubleSwish
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
def get_parser(): def get_parser():

View File

@ -82,8 +82,7 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import OnnxModel, greedy_search
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@ -20,7 +20,6 @@ from typing import List
import k2 import k2
import torch import torch
from beam_search import Hypothesis, HypothesisList, get_hyps_shape from beam_search import Hypothesis, HypothesisList, get_hyps_shape
# The force alignment problem can be formulated as finding # The force alignment problem can be formulated as finding

View File

@ -107,9 +107,6 @@ import k2
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
# from asr_datamodule import LibriSpeechAsrDataModule
from gigaspeech import GigaSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest, fast_beam_search_nbest,
@ -120,6 +117,9 @@ from beam_search import (
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
) )
# from asr_datamodule import LibriSpeechAsrDataModule
from gigaspeech import GigaSpeechAsrDataModule
from gigaspeech_scoring import asr_text_post_processing from gigaspeech_scoring import asr_text_post_processing
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model

View File

@ -80,6 +80,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -976,9 +977,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -65,16 +65,15 @@ from typing import Dict, List
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool
def get_parser(): def get_parser():

View File

@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py
import argparse import argparse
import logging import logging
from typing import Tuple
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from scaling import BasicNorm, DoubleSwish
from typing import Tuple
from torch import Tensor, nn from torch import Tensor, nn
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
from icefall.profiler import get_model_profile from icefall.profiler import get_model_profile
from scaling import BasicNorm, DoubleSwish
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
def get_parser(): def get_parser():

View File

@ -75,8 +75,7 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import OnnxModel, greedy_search
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@ -24,7 +24,6 @@ To run this file, do:
""" """
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model from train import get_params, get_transducer_model

View File

@ -81,6 +81,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -878,9 +879,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -81,6 +81,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -902,9 +903,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

Some files were not shown because too many files have changed in this diff Show More