Merge remote-tracking branch 'dan/master' into doc-force-alignment-kaldi

Author: Fangjun Kuang, 2024-06-12 17:34:26 +08:00
commit cb21b878c0
575 changed files with 91526 additions and 1409 deletions

.github/scripts/.gitignore (new file)

@@ -0,0 +1 @@
piper_phonemize.html

.github/scripts/audioset/AT/run.sh (new executable file)

@@ -0,0 +1,94 @@
#!/usr/bin/env bash
set -ex
python3 -m pip install onnxoptimizer onnxsim
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/audioset/AT
function test_pretrained() {
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
repo=$(basename $repo_url)
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo/exp
git lfs pull --include pretrained.pt
ln -s pretrained.pt epoch-99.pt
ls -lh
popd
log "test pretrained.pt"
python3 zipformer/pretrained.py \
--checkpoint $repo/exp/pretrained.pt \
--label-dict $repo/data/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav
log "test jit export"
ls -lh $repo/exp/
python3 zipformer/export.py \
--exp-dir $repo/exp \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--jit 1
ls -lh $repo/exp/
log "test jit models"
python3 zipformer/jit_pretrained.py \
--nn-model-filename $repo/exp/jit_script.pt \
--label-dict $repo/data/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav
log "test onnx export"
ls -lh $repo/exp/
python3 zipformer/export-onnx.py \
--exp-dir $repo/exp \
--epoch 99 \
--avg 1 \
--use-averaged-model 0
ls -lh $repo/exp/
pushd $repo/exp/
mv model-epoch-99-avg-1.onnx model.onnx
mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
popd
ls -lh $repo/exp/
log "test onnx models"
for m in model.onnx model.int8.onnx; do
log "$m"
python3 zipformer/onnx_pretrained.py \
--model-filename $repo/exp/$m \
--label-dict $repo/data/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav
done
log "prepare data for uploading to huggingface"
dst=/icefall/model-onnx
mkdir -p $dst
cp -v $repo/exp/*.onnx $dst/
cp -v $repo/data/* $dst/
cp -av $repo/test_wavs $dst
ls -lh $dst
ls -lh $dst/test_wavs
}
test_pretrained
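For anyone consuming the exported .onnx files outside of icefall: they are plain onnxruntime models. The sketch below shows the generic way to inspect and run one; the tensor names ("x", "x_lens"), the 80-dim fbank feature shape, and the sigmoid readout are assumptions about a typical zipformer audio-tagging export, not guarantees, so check session.get_inputs() first.

# Hedged sketch: inspect and run an exported model with onnxruntime.
# The input names, feature shape, and sigmoid readout below are assumptions;
# verify them against the real model (e.g. with session.get_inputs()).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
print([(i.name, i.shape) for i in session.get_inputs()])  # actual input names

features = np.random.randn(1, 1000, 80).astype(np.float32)  # (N, T, C) fbank
feature_lens = np.array([features.shape[1]], dtype=np.int64)

outputs = session.run(None, {"x": features, "x_lens": feature_lens})
probs = 1.0 / (1.0 + np.exp(-outputs[0]))  # audio tagging is multi-label
print(probs.shape)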

@@ -11,6 +11,7 @@ ARG _KALDIFEAT_VERSION="${KALDIFEAT_VERSION}+cpu.torch${TORCH_VERSION}"
 RUN apt-get update -y && \
     apt-get install -qq -y \
+      cmake \
       ffmpeg \
       git \
       git-lfs \
@@ -35,7 +36,9 @@ RUN pip install --no-cache-dir \
       \
       git+https://github.com/lhotse-speech/lhotse \
       kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \
+      cython \
       dill \
+      espnet_tts_frontend \
       graphviz \
       kaldi-decoder \
       kaldi_native_io \
@@ -44,10 +47,15 @@ RUN pip install --no-cache-dir \
       kaldilm \
       matplotlib \
       multi_quantization \
+      numba \
       numpy \
+      onnxoptimizer \
+      onnxsim \
       onnx \
       onnxmltools \
       onnxruntime \
+      piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \
+      pypinyin==0.50.0 \
       pytest \
       sentencepiece>=0.1.96 \
       six \

@@ -6,8 +6,8 @@ import json
 def version_gt(a, b):
-    a_major, a_minor = a.split(".")[:2]
-    b_major, b_minor = b.split(".")[:2]
+    a_major, a_minor = list(map(int, a.split(".")))[:2]
+    b_major, b_minor = list(map(int, b.split(".")))[:2]
     if a_major > b_major:
         return True
@@ -18,8 +18,8 @@ def version_gt(a, b):
 def version_ge(a, b):
-    a_major, a_minor = a.split(".")[:2]
-    b_major, b_minor = b.split(".")[:2]
+    a_major, a_minor = list(map(int, a.split(".")))[:2]
+    b_major, b_minor = list(map(int, b.split(".")))[:2]
     if a_major > b_major:
         return True
@@ -43,11 +43,16 @@ def get_torchaudio_version(torch_version):
 def get_matrix():
-    k2_version = "1.24.4.dev20231220"
-    kaldifeat_version = "1.25.3.dev20231221"
-    version = "1.2"
-    python_version = ["3.8", "3.9", "3.10", "3.11"]
-    torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+    k2_version = "1.24.4.dev20240223"
+    kaldifeat_version = "1.25.4.dev20240223"
+    version = "20240606"
+    python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
+    torch_version = []
+    torch_version += ["1.13.0", "1.13.1"]
+    torch_version += ["2.0.0", "2.0.1"]
+    torch_version += ["2.1.0", "2.1.1", "2.1.2"]
+    torch_version += ["2.2.0", "2.2.1", "2.2.2"]
+    torch_version += ["2.3.0", "2.3.1"]

     matrix = []
     for p in python_version:
@@ -57,10 +62,27 @@ def get_matrix():
             if version_gt(p, "3.10") and not version_gt(t, "2.0"):
                 continue

+            # only torch>=2.2.0 supports python 3.12
+            if version_gt(p, "3.11") and not version_gt(t, "2.1"):
+                continue
+
+            k2_version_2 = k2_version
+            kaldifeat_version_2 = kaldifeat_version
+            if t == "2.2.2":
+                k2_version_2 = "1.24.4.dev20240328"
+                kaldifeat_version_2 = "1.25.4.dev20240329"
+            elif t == "2.3.0":
+                k2_version_2 = "1.24.4.dev20240425"
+                kaldifeat_version_2 = "1.25.4.dev20240425"
+            elif t == "2.3.1":
+                k2_version_2 = "1.24.4.dev20240606"
+                kaldifeat_version_2 = "1.25.4.dev20240606"
+
             matrix.append(
                 {
-                    "k2-version": k2_version,
-                    "kaldifeat-version": kaldifeat_version,
+                    "k2-version": k2_version_2,
+                    "kaldifeat-version": kaldifeat_version_2,
                     "version": version,
                     "python-version": p,
                     "torch-version": t,

.github/scripts/generate-piper-phonemize-page.py (new file)

@@ -0,0 +1,29 @@
#!/usr/bin/env python3
def main():
prefix = (
"https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
)
files = [
"piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
"piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
]
with open("piper_phonemize.html", "w") as f:
for file in files:
url = prefix + file
f.write(f'<a href="{url}">{file}</a><br/>\n')
if __name__ == "__main__":
main()
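The generated page is consumed by pip's -f/--find-links option (as in pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html elsewhere in this commit): pip scans the page's anchor tags for wheel filenames compatible with the current Python and platform. A small illustrative reader of such a page, using only the standard library (this is not pip's actual implementation):

# Illustrative only: list the wheels advertised by the generated page.
import re
import urllib.request

url = "https://k2-fsa.github.io/icefall/piper_phonemize.html"
html = urllib.request.urlopen(url).read().decode("utf-8")
for href in re.findall(r'href="([^"]+\.whl)"', html):
    print(href.rsplit("/", 1)[-1])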

@@ -15,9 +15,9 @@ function prepare_data() {
   # cause OOM error for CI later.
   mkdir -p download/lm
   pushd download/lm
-  wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
-  wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
-  wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+  wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
+  wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
+  wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt

   ls -lh
   gunzip librispeech-lm-norm.txt.gz
@@ -64,6 +64,46 @@ function run_diagnostics() {
     --print-diagnostics 1
 }

+function test_streaming_zipformer_ctc_hlg() {
+  repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
+  log "Downloading pre-trained model from $repo_url"
+  git lfs install
+  git clone $repo_url
+  repo=$(basename $repo_url)
+
+  rm $repo/exp-ctc-rnnt-small/*.onnx
+  ls -lh $repo/exp-ctc-rnnt-small
+
+  # export models to onnx
+  ./zipformer/export-onnx-streaming-ctc.py \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
+    --epoch 30 \
+    --avg 3 \
+    --exp-dir $repo/exp-ctc-rnnt-small \
+    --causal 1 \
+    --use-ctc 1 \
+    --chunk-size 16 \
+    --left-context-frames 128 \
+    \
+    --num-encoder-layers 2,2,2,2,2,2 \
+    --feedforward-dim 512,768,768,768,768,768 \
+    --encoder-dim 192,256,256,256,256,256 \
+    --encoder-unmasked-dim 192,192,192,192,192,192
+
+  ls -lh $repo/exp-ctc-rnnt-small
+
+  for wav in 0.wav 1.wav 8k.wav; do
+    python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
+      --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
+      --words $repo/data/lang_bpe_500/words.txt \
+      --HLG $repo/data/lang_bpe_500/HLG.fst \
+      $repo/test_wavs/$wav
+  done
+
+  rm -rf $repo
+}
+
 function test_pruned_transducer_stateless_2022_03_12() {
   repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
@@ -1577,6 +1617,7 @@ function test_transducer_bpe_500_2021_12_23() {
 prepare_data
 run_diagnostics
+test_streaming_zipformer_ctc_hlg
 test_pruned_transducer_stateless_2022_03_12
 test_pruned_transducer_stateless2_2022_04_29
 test_pruned_transducer_stateless3_2022_04_29

.github/scripts/ljspeech/TTS/run.sh (new executable file)

@@ -0,0 +1,157 @@
#!/usr/bin/env bash
set -ex
python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
python3 -m pip install espnet_tts_frontend
python3 -m pip install numba
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/ljspeech/TTS
sed -i.bak s/600/8/g ./prepare.sh
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
sed -i.bak s/500/5/g ./prepare.sh
git diff
function prepare_data() {
# We have created a subset of the data for testing
#
mkdir download
pushd download
wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
tar xvf LJSpeech-1.1.tar.bz2
popd
./prepare.sh
tree .
}
function train() {
pushd ./vits
sed -i.bak s/200/3/g ./train.py
git diff .
popd
for t in low medium high; do
./vits/train.py \
--exp-dir vits/exp-$t \
--model-type $t \
--num-epochs 1 \
--save-every-n 1 \
--num-buckets 2 \
--tokens data/tokens.txt \
--max-duration 20
ls -lh vits/exp-$t
done
}
function infer() {
for t in low medium high; do
./vits/infer.py \
--num-buckets 2 \
--model-type $t \
--epoch 1 \
--exp-dir ./vits/exp-$t \
--tokens data/tokens.txt \
--max-duration 20
done
}
function export_onnx() {
for t in low medium high; do
./vits/export-onnx.py \
--model-type $t \
--epoch 1 \
--exp-dir ./vits/exp-$t \
--tokens data/tokens.txt
ls -lh vits/exp-$t/
done
}
function test_medium() {
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12
./vits/export-onnx.py \
--model-type medium \
--epoch 820 \
--exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt
ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp
./vits/test_onnx.py \
--model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
--tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
--output-filename /icefall/test-medium.wav
ls -lh /icefall/test-medium.wav
d=/icefall/vits-icefall-en_US-ljspeech-medium
mkdir $d
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx
rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12
pushd $d
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
rm espeak-ng-data.tar.bz2
cd ..
tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
rm -rf vits-icefall-en_US-ljspeech-medium
ls -lh *.tar.bz2
popd
}
function test_low() {
git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12
./vits/export-onnx.py \
--model-type low \
--epoch 1600 \
--exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt
ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp
./vits/test_onnx.py \
--model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
--tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
--output-filename /icefall/test-low.wav
ls -lh /icefall/test-low.wav
d=/icefall/vits-icefall-en_US-ljspeech-low
mkdir $d
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx
rm -rf icefall-tts-ljspeech-vits-low-2024-03-12
pushd $d
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
rm espeak-ng-data.tar.bz2
cd ..
tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
rm -rf vits-icefall-en_US-ljspeech-low
ls -lh *.tar.bz2
popd
}
prepare_data
train
infer
export_onnx
rm -rf vits/exp-{low,medium,high}
test_medium
test_low

.github/workflows/audioset.yml (new file)

@@ -0,0 +1,137 @@
name: audioset
on:
push:
branches:
- master
pull_request:
branches:
- master
workflow_dispatch:
concurrency:
group: audioset-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
echo "::set-output name=matrix::${MATRIX}"
audioset:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Free space
shell: bash
run: |
ls -lh
df -h
rm -rf /opt/hostedtoolcache
df -h
echo "pwd: $PWD"
echo "github.workspace ${{ github.workspace }}"
- name: Run tests
uses: addnab/docker-run-action@v3
with:
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
run: |
export PYTHONPATH=/icefall:$PYTHONPATH
cd /icefall
git config --global --add safe.directory /icefall
.github/scripts/audioset/AT/run.sh
- name: Show model files
shell: bash
run: |
sudo chown -R runner ./model-onnx
ls -lh ./model-onnx
chmod -x ./model-onnx/class_labels_indices.csv
echo "----------"
ls -lh ./model-onnx/*
- name: Upload model to huggingface
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
git clone https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 huggingface
cd huggingface
git fetch
git pull
git merge -m "merge remote" --ff origin main
cp ../model-onnx/*.onnx ./
cp ../model-onnx/*.csv ./
cp -a ../model-onnx/test_wavs ./
ls -lh
git add .
git status
git commit -m "update models"
git status
git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 main || true
rm -rf huggingface
- name: Prepare for release
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
shell: bash
run: |
d=sherpa-onnx-zipformer-audio-tagging-2024-04-09
mv ./model-onnx $d
tar cjvf ${d}.tar.bz2 $d
ls -lh
- name: Release exported onnx models
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
overwrite: true
file: sherpa-onnx-*.tar.bz2
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: audio-tagging-models

@@ -56,11 +56,14 @@ jobs:
       - name: Build doc
         shell: bash
         run: |
+          .github/scripts/generate-piper-phonemize-page.py
+
           cd docs
           python3 -m pip install -r ./requirements.txt
           make html
           touch build/html/.nojekyll
+          cp -v ../piper_phonemize.html ./build/html/

       - name: Deploy
         uses: peaceiris/actions-gh-pages@v3
         with:

@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+        image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]

     steps:
       # refer to https://github.com/actions/checkout

.github/workflows/ljspeech.yml (new file)

@@ -0,0 +1,102 @@
name: ljspeech
on:
push:
branches:
- master
pull_request:
branches:
- master
workflow_dispatch:
concurrency:
group: ljspeech-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
echo "::set-output name=matrix::${MATRIX}"
ljspeech:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Free space
shell: bash
run: |
ls -lh
df -h
rm -rf /opt/hostedtoolcache
df -h
echo "pwd: $PWD"
echo "github.workspace ${{ github.workspace }}"
- name: Run tests
uses: addnab/docker-run-action@v3
with:
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
run: |
export PYTHONPATH=/icefall:$PYTHONPATH
cd /icefall
git config --global --add safe.directory /icefall
.github/scripts/ljspeech/TTS/run.sh
- name: display files
shell: bash
run: |
ls -lh
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
with:
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
path: ./*.wav
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
with:
name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
path: ./*.wav
- name: Release exported onnx models
if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
overwrite: true
file: vits-icefall-*.tar.bz2
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: tts-models

@@ -14,13 +14,20 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+        image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]

     steps:
       # refer to https://github.com/actions/checkout
       - uses: actions/checkout@v2
         with:
           fetch-depth: 0

+      - name: Free space
+        shell: bash
+        run: |
+          df -h
+          rm -rf /opt/hostedtoolcache
+          df -h
+
       - name: Run the build process with Docker
         uses: addnab/docker-run-action@v3
         with:

@@ -49,7 +49,7 @@ jobs:
       - name: Install Python dependencies
         run: |
-          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
+          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
           # Click issue fixed in https://github.com/psf/black/pull/2966

       - name: Run flake8
@@ -67,3 +67,9 @@ jobs:
         working-directory: ${{github.workspace}}
         run: |
           black --check --diff .
+
+      - name: Run isort
+        shell: bash
+        working-directory: ${{github.workspace}}
+        run: |
+          isort --check --diff .

@@ -59,4 +59,7 @@
             cd /icefall
             git config --global --add safe.directory /icefall

+            python3 -m torch.utils.collect_env
+            python3 -m k2.version
+
             .github/scripts/yesno/ASR/run.sh

@@ -26,7 +26,7 @@ repos:
     # E121,E123,E126,E226,E24,E704,W503,W504

   - repo: https://github.com/pycqa/isort
-    rev: 5.11.5
+    rev: 5.12.0
     hooks:
       - id: isort
         args: ["--profile=black"]

@@ -74,6 +74,9 @@ The [LibriSpeech][librispeech] recipe supports the most comprehensive set of models
   - LSTM-based Predictor
   - [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/)

+#### Whisper
+
+  - [OpenAI Whisper](https://arxiv.org/abs/2212.04356) (We support fine-tuning on AiShell-1.)

 If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details.

 We would like to highlight the performance of some of the recipes here.

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.7
-ARG K2_VERSION="1.24.4.dev20230725+cuda11.3.torch1.12.1"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.3.torch1.12.1"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.3.torch1.12.1"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.3.torch1.12.1"
 ARG TORCHAUDIO_VERSION="0.12.1+cu113"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +55,8 @@ RUN pip install --no-cache-dir \
       onnx \
       onnxruntime \
       onnxmltools \
+      onnxoptimizer \
+      onnxsim \
       multi_quantization \
       typeguard \
       numpy \

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.9
-ARG K2_VERSION="1.24.4.dev20231021+cuda11.6.torch1.13.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.6.torch1.13.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.6.torch1.13.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.6.torch1.13.0"
 ARG TORCHAUDIO_VERSION="0.13.0+cu116"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +55,8 @@ RUN pip install --no-cache-dir \
       onnx \
       onnxruntime \
       onnxmltools \
+      onnxoptimizer \
+      onnxsim \
       multi_quantization \
       typeguard \
       numpy \

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.7
-ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda10.2.torch1.9.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda10.2.torch1.9.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda10.2.torch1.9.0"
 ARG TORCHAUDIO_VERSION="0.9.0"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -69,6 +69,8 @@ RUN pip uninstall -y tqdm && \
       onnx \
       onnxruntime \
       onnxmltools \
+      onnxoptimizer \
+      onnxsim \
       multi_quantization \
       typeguard \
       numpy \

@@ -1,12 +1,13 @@
 FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel

+# python 3.10
 ENV LC_ALL C.UTF-8

 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20231021+cuda11.7.torch2.0.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.7.torch2.0.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.7.torch2.0.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.7.torch2.0.0"
 ARG TORCHAUDIO_VERSION="2.0.0+cu117"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +56,8 @@ RUN pip install --no-cache-dir \
       onnx \
       onnxruntime \
       onnxmltools \
+      onnxoptimizer \
+      onnxsim \
       multi_quantization \
       typeguard \
       numpy \

@@ -1,12 +1,13 @@
 FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel

+# python 3.10
 ENV LC_ALL C.UTF-8

 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20231021+cuda11.8.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.8.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.1.0"
 ARG TORCHAUDIO_VERSION="2.1.0+cu118"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +56,8 @@ RUN pip install --no-cache-dir \
       onnx \
       onnxruntime \
       onnxmltools \
+      onnxoptimizer \
+      onnxsim \
       multi_quantization \
       typeguard \
       numpy \

@@ -1,12 +1,13 @@
 FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel

+# python 3.10
 ENV LC_ALL C.UTF-8

 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20231021+cuda12.1.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda12.1.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.1.0"
 ARG TORCHAUDIO_VERSION="2.1.0+cu121"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +56,8 @@ RUN pip install --no-cache-dir \
       onnx \
       onnxruntime \
       onnxmltools \
+      onnxoptimizer \
+      onnxsim \
       multi_quantization \
       typeguard \
       numpy \

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.0"
ARG TORCHAUDIO_VERSION="2.2.0+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.0"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.0"
ARG TORCHAUDIO_VERSION="2.2.0+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240328+cuda11.8.torch2.2.2"
ARG KALDIFEAT_VERSION="1.25.4.dev20240329+cuda11.8.torch2.2.2"
ARG TORCHAUDIO_VERSION="2.2.2+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.2.2-cuda12.1-cudnn8-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240328+cuda12.1.torch2.2.2"
ARG KALDIFEAT_VERSION="1.25.4.dev20240329+cuda12.1.torch2.2.2"
ARG TORCHAUDIO_VERSION="2.2.2+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240606+cuda11.8.torch2.3.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240606+cuda11.8.torch2.3.1"
ARG TORCHAUDIO_VERSION="2.3.1+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240606+cuda12.1.torch2.3.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240606+cuda12.1.torch2.3.1"
ARG TORCHAUDIO_VERSION="2.3.1+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -30,7 +30,7 @@ of langugae model integration.
 First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here <https://arxiv.org/abs/2002.11268>`_
 to address the language information mismatch between the training
 corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
-are acoustically similar, DR derives the following formular for decoding with Bayes' theorem:
+are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:

 .. math::

@@ -41,7 +41,7 @@ are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:
 where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively.
-Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to
+Here, the source domain LM is trained on the training corpus. The only difference in the above formula compared to
 shallow fusion is the subtraction of the source domain LM.

 Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is
@@ -58,7 +58,7 @@ during decoding for transducer model:
 In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR,
 the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
-LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
+LODR achieves similar performance compared to DR in both intra-domain and cross-domain settings.
 As a bi-gram is much faster to evaluate, LODR is usually much faster.

 Now, we will show you how to use LODR in ``icefall``.
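The body of the ``.. math::`` block above did not survive extraction. From the surrounding definitions (:math:`\lambda_1` weights the target-domain LM, :math:`\lambda_2` weights the source-domain LM, which is subtracted), the DR decoding rule from the cited paper has the following form; this is a hedged reconstruction, not the verbatim equation from the file:

.. math::

    \text{score}(y \mid x) = \log p_{\text{AM}}(y \mid x)
        + \lambda_1 \log p_{\text{Target LM}}(y)
        - \lambda_2 \log p_{\text{Source LM}}(y)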

@@ -139,7 +139,7 @@ A few parameters can be tuned to further boost the performance of shallow fusion

 - ``--lm-scale``

   Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
-  the LM score may dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
+  the LM score might be dominant during decoding, leading to bad WER. A typical value of this is around 0.3.

 - ``--beam-size``
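As background: shallow fusion typically interpolates the ASR log-probability with the external LM log-probability at each decoding step, with ``--lm-scale`` acting as the weight :math:`\lambda` (a standard formulation, not quoted from the file):

.. math::

    \text{score}(y) = \log p_{\text{ASR}}(y \mid x) + \lambda \log p_{\text{LM}}(y)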

@@ -74,6 +74,10 @@ to install dependencies of `icefall`_:

   pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html

+  # For users from China
+  # If you cannot access huggingface, please use
+  # pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu-cn.html
+
   # Install the latest version of lhotse
   pip install git+https://github.com/lhotse-speech/lhotse

@@ -206,6 +206,9 @@ We will install `k2`_ from pre-compiled wheels by following

   (test-icefall) kuangfangjun:~$ pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda.html
+  # For users from China
+  # If you cannot access huggingface, please use
+  # pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda-cn.html

   Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
   Looking in links: https://k2-fsa.github.io/k2/cuda.html

@@ -0,0 +1,225 @@
Finetune from a pre-trained Zipformer model with adapters
=========================================================
This tutorial shows you how to fine-tune a pre-trained **Zipformer**
transducer model on a new dataset with adapters.
Adapters are compact and efficient modules that can be integrated into a pre-trained model
to improve its performance on a new domain. Adapters are injected
between different modules of the well-trained neural network, and during training only the
parameters of the adapters are updated. This achieves competitive performance
while requiring much less GPU memory than full fine-tuning. For more details about adapters,
please refer to the original `paper <https://arxiv.org/pdf/1902.00751.pdf#/>`_.
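As a mental model, a bottleneck adapter with a residual connection can be sketched in a few lines of PyTorch. This is an illustrative sketch only, not the exact module used in ``zipformer_adapter``, which may differ in placement, normalization, and initialization:

.. code-block:: python

    import torch
    import torch.nn as nn

    class Adapter(nn.Module):
        """Illustrative bottleneck adapter: project down to a small
        dimension, apply a non-linearity, project back up, and add the
        result to the input (residual connection)."""

        def __init__(self, d_model: int, adapter_dim: int = 8):
            super().__init__()
            self.down = nn.Linear(d_model, adapter_dim)
            self.up = nn.Linear(adapter_dim, d_model)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # If the adapter is disabled (or self.up is zero-initialized),
            # this reduces to the identity, recovering the base model.
            return x + self.up(torch.relu(self.down(x)))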
.. HINT::
We assume you have read the page :ref:`install icefall` and have set up
the environment for ``icefall``.
.. HINT::
We recommend using one or more GPUs to run this recipe.
For illustration purposes, we fine-tune the Zipformer transducer model
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
own data for fine-tuning if you create a manifest for your new dataset.
Data preparation
----------------
Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
to prepare the fine-tuning data used in this tutorial. Only the small subset of GigaSpeech is required.
Model preparation
-----------------
We are using the Zipformer model trained on full LibriSpeech (960 hours) as the initialization. The
checkpoint of the model can be downloaded via the following command:
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
$ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ cd ../data/lang_bpe_500
$ git lfs pull --include bpe.model
$ cd ../../..
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
decoding on the GigaSpeech test sets:
.. code-block:: bash
./zipformer/decode_gigaspeech.py \
--epoch 99 \
--avg 1 \
--exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
--use-averaged-model 0 \
--max-duration 1000 \
--decoding-method greedy_search
You should see the following numbers:
.. code-block::
For dev, WER of different settings are:
greedy_search 20.06 best for dev
For test, WER of different settings are:
greedy_search 19.27 best for test
Fine-tune with adapter
----------------------
We insert 4 adapters, each with a residual connection, in every ``Zipformer2EncoderLayer``.
The original model parameters remain untouched during training and only the parameters of
the adapters are updated. The following command starts a fine-tuning experiment with adapters:
.. code-block:: bash
$ do_finetune=1
$ use_adapters=1
$ adapter_dim=8
$ ./zipformer_adapter/train.py \
--world-size 2 \
--num-epochs 20 \
--start-epoch 1 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--use-fp16 1 \
--base-lr 0.045 \
--use-adapters $use_adapters --adapter-dim $adapter_dim \
--bpe-model data/lang_bpe_500/bpe.model \
--do-finetune $do_finetune \
--master-port 13022 \
--finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
--max-duration 1000
The following arguments are related to fine-tuning:
- ``--do-finetune``
If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
**Note that if you want to resume your fine-tuning experiment from certain epochs, you
need to set this to False.**
- ``--use-adapters``

  Whether adapters are used during fine-tuning.
- ``--adapter-dim``
The bottleneck dimension of the adapter module. Typically a small number.
You should notice that in the training log, the total number of trainable parameters is shown:
.. code-block::
2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)
The trainable parameters make up only 1.15% of the entire model, so training will be much faster
and require less memory than full fine-tuning.
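This percentage can be reproduced with the standard PyTorch idiom below; ``model`` stands for the loaded fine-tuning model:

.. code-block:: python

    # Count trainable vs. total parameters of a torch.nn.Module.
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_total = sum(p.numel() for p in model.parameters())
    print(f"{num_trainable} trainable ({100 * num_trainable / num_total:.3f}% of the whole model)")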
Decoding
--------
After training, let's test the WERs. To test the WERs on the GigaSpeech set,
you can execute the following command:
.. code-block:: bash
$ epoch=20
$ avg=10
$ use_adapters=1
$ adapter_dim=8
$ ./zipformer_adapter/decode.py \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--max-duration 600 \
--use-adapters $use_adapters \
--adapter-dim $adapter_dim \
--decoding-method greedy_search
You should see the following numbers:
.. code-block::
For dev, WER of different settings are:
greedy_search 15.44 best for dev
For test, WER of different settings are:
greedy_search 15.42 best for test
The WER on test set is improved from 19.27 to 15.42, demonstrating the effectiveness of adapters.
The same model can be used to perform decoding on LibriSpeech test sets. You can deactivate the adapters
to recover the original model's performance:
.. code-block:: bash
$ epoch=20
$ avg=1
$ use_adapters=0
$ adapter_dim=8
$ ./zipformer_adapter/decode.py \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--max-duration 600 \
--use-adapters $use_adapters \
--adapter-dim $adapter_dim \
--decoding-method greedy_search
.. code-block::
For dev, WER of different settings are:
greedy_search 2.23 best for test-clean
For test, WER of different settings are:
greedy_search 4.96 best for test-other
The numbers are the same as reported in `icefall <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md#normal-scaled-model-number-of-model-parameters-65549011-ie-6555-m>`_. So adapter-based
fine-tuning is also very flexible, as the same model can be used for decoding on both the original and the target domain.
Export the model
----------------
After training, the model can be exported to ``onnx`` format easily using the following command:
.. code-block:: bash
$ use_adapters=1
$ adapter_dim=16
$ ./zipformer_adapter/export-onnx.py \
--tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
--use-averaged-model 1 \
--epoch 20 \
--avg 10 \
--exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
--use-adapters $use_adapters \
--adapter-dim $adapter_dim \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"

@@ -0,0 +1,140 @@
Finetune from a supervised pre-trained Zipformer model
======================================================
This tutorial shows you how to fine-tune a supervised pre-trained **Zipformer**
transducer model on a new dataset.
.. HINT::
We assume you have read the page :ref:`install icefall` and have set up
the environment for ``icefall``.
.. HINT::
We recommend using one or more GPUs to run this recipe.
For illustration purposes, we fine-tune the Zipformer transducer model
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
own data for fine-tuning if you create a manifest for your new dataset.
Data preparation
----------------
Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
to prepare the fine-tuning data used in this tutorial. Only the small subset of GigaSpeech is required.
Model preparation
-----------------
We are using the Zipformer model trained on full LibriSpeech (960 hours) as the initialization. The
checkpoint of the model can be downloaded via the following command:
.. code-block:: bash
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
$ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
$ git lfs pull --include "pretrained.pt"
$ ln -s pretrained.pt epoch-99.pt
$ cd ../data/lang_bpe_500
$ git lfs pull --include bpe.model
$ cd ../../..
Before fine-tuning, let's test the model's WER on the new domain. The following command performs
decoding on the GigaSpeech test sets:
.. code-block:: bash
./zipformer/decode_gigaspeech.py \
--epoch 99 \
--avg 1 \
--exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
--use-averaged-model 0 \
--max-duration 1000 \
--decoding-method greedy_search
You should see the following numbers:
.. code-block::
For dev, WER of different settings are:
greedy_search 20.06 best for dev
For test, WER of different settings are:
greedy_search 19.27 best for test
Fine-tune
---------
Since LibriSpeech and GigaSpeech are both English datasets, we can initialize the whole
Zipformer model with the checkpoint downloaded in the previous step (otherwise we should consider
initializing the stateless decoder and joiner from scratch due to the mismatch of the output
vocabulary). The following command starts a fine-tuning experiment:
.. code-block:: bash
$ use_mux=0
$ do_finetune=1
$ ./zipformer/finetune.py \
--world-size 2 \
--num-epochs 20 \
--start-epoch 1 \
--exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
--use-fp16 1 \
--base-lr 0.0045 \
--bpe-model data/lang_bpe_500/bpe.model \
--do-finetune $do_finetune \
--use-mux $use_mux \
--master-port 13024 \
--finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
--max-duration 1000
The following arguments are related to fine-tuning:
- ``--base-lr``
The learning rate used for fine-tuning. We suggest setting a **small** learning rate for fine-tuning;
otherwise the model may forget the initialization very quickly. A reasonable value is around
1/10 of the original lr, i.e., 0.0045.
- ``--do-finetune``
If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
**Note that if you want to resume a fine-tuning experiment from a certain epoch, you
need to set this to False.**
- ``--finetune-ckpt``
The path to the pre-trained checkpoint (used for initialization).
- ``--use-mux``
If True, mix the fine-tuning data with the original training data using `CutSet.mux <https://lhotse.readthedocs.io/en/latest/api.html#lhotse.supervision.SupervisionSet.mux>`_ (see the sketch after this list).
This helps maintain the model's performance on the original domain if the original training
data is available. **If you don't have the original training data, please set it to False.**
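Below is a minimal sketch of the mixing that ``--use-mux 1`` enables, assuming
hypothetical manifest paths; the actual logic lives inside ``finetune.py``.

.. code-block:: python

   # A sketch only: the manifest paths below are hypothetical.
   from lhotse import CutSet, load_manifest_lazy

   # Fine-tuning data (GigaSpeech S) and original training data (LibriSpeech).
   giga_cuts = load_manifest_lazy("data/fbank/gigaspeech_cuts_S.jsonl.gz")
   libri_cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train.jsonl.gz")

   # Interleave the two sources; the weights are relative sampling
   # probabilities, so roughly half of each mini-batch comes from each domain.
   train_cuts = CutSet.mux(giga_cuts, libri_cuts, weights=[0.5, 0.5])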
After fine-tuning, let's test the WERs. You can do this via the following command:
.. code-block:: bash
$ use_mux=0
$ do_finetune=1
$ ./zipformer/decode_gigaspeech.py \
--epoch 20 \
--avg 10 \
--exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
--use-averaged-model 1 \
--max-duration 1000 \
--decoding-method greedy_search
You should see numbers similar to the ones below:
.. code-block:: text
For dev, WER of different settings are:
greedy_search 13.47 best for dev
For test, WER of different settings are:
greedy_search 13.66 best for test
Compared to the original checkpoint, the fine-tuned model achieves much lower WERs
on the GigaSpeech test sets.

View File

@ -0,0 +1,16 @@
Fine-tune a pre-trained model
=============================
After pre-training on publicly available datasets, an ASR model is already capable of
performing general speech recognition with relatively high accuracy. However, accuracy
can still be low on domains that differ substantially from the original training
set. In that case, we can fine-tune the model with a small amount of additional labelled
data to improve the performance on new domains.
.. toctree::
:maxdepth: 2
:caption: Table of Contents
from_supervised/finetune_zipformer
adapter/finetune_adapter

View File

@ -1,4 +1,4 @@
VITS VITS-LJSpeech
=============== ===============
This tutorial shows you how to train a VITS model This tutorial shows you how to train a VITS model
@ -13,6 +13,14 @@ with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_ The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
Install extra dependencies
--------------------------
.. code-block:: bash
pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
pip install numba espnet_tts_frontend
Data preparation Data preparation
---------------- ----------------
@ -56,7 +64,8 @@ Training
--start-epoch 1 \ --start-epoch 1 \
--use-fp16 1 \ --use-fp16 1 \
--exp-dir vits/exp \ --exp-dir vits/exp \
--tokens data/tokens.txt --tokens data/tokens.txt \
--model-type high \
--max-duration 500 --max-duration 500
.. note:: .. note::
@ -64,6 +73,11 @@ Training
You can adjust the hyper-parameters to control the size of the VITS model and You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``. the training configurations. For more details, please run ``./vits/train.py --help``.
.. warning::
If you want a model that runs faster on CPU, please use ``--model-type low``
or ``--model-type medium``.
.. note:: .. note::
The training can take a long time (usually a couple of days). The training can take a long time (usually a couple of days).
@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
Export models Export models
------------- -------------
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. ``vits-epoch-*.onnx``.
.. code-block:: bash .. code-block:: bash
@ -120,4 +134,68 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models If you don't want to train from scratch, you can download the pretrained models
by visiting the following link: by visiting the following link:
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_ - ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_
Usage in sherpa-onnx
--------------------
The following describes how to test the exported ONNX model in `sherpa-onnx`_.
.. hint::
`sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.
We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
Please refer to `<https://k2-fsa.github.io/sherpa/onnx/>`_
for more documentation.
Install sherpa-onnx
^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
pip install sherpa-onnx
To check that you have installed `sherpa-onnx`_ successfully, please run:
.. code-block:: bash
which sherpa-onnx-offline-tts
sherpa-onnx-offline-tts --help
Download lexicon files
^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: bash
cd /tmp
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
Run sherpa-onnx
^^^^^^^^^^^^^^^
.. code-block:: bash
cd egs/ljspeech/TTS
sherpa-onnx-offline-tts \
--vits-model=vits/exp/vits-epoch-1000.onnx \
--vits-tokens=data/tokens.txt \
--vits-data-dir=/tmp/espeak-ng-data \
--num-threads=1 \
--output-filename=./high.wav \
"Ask not what your country can do for you; ask what you can do for your country."
.. hint::
You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
while it is being generated.
You should get a file ``high.wav`` after running the above command.
Congratulations! You have successfully trained and exported a text-to-speech
model and run it with `sherpa-onnx`_.
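If you prefer the Python API of `sherpa-onnx`_, the following sketch shows the
same test; the config classes used here are assumptions based on the sherpa-onnx
Python examples, so please consult its documentation for the authoritative API.

.. code-block:: python

   import sherpa_onnx
   import soundfile as sf

   # Paths match the command-line example above.
   config = sherpa_onnx.OfflineTtsConfig(
       model=sherpa_onnx.OfflineTtsModelConfig(
           vits=sherpa_onnx.OfflineTtsVitsModelConfig(
               model="vits/exp/vits-epoch-1000.onnx",
               tokens="data/tokens.txt",
               data_dir="/tmp/espeak-ng-data",
           ),
           num_threads=1,
       ),
   )
   tts = sherpa_onnx.OfflineTts(config)
   audio = tts.generate("Ask not what your country can do for you.")
   sf.write("high.wav", audio.samples, samplerate=audio.sample_rate)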

View File

@ -1,4 +1,4 @@
VITS VITS-VCTK
=============== ===============
This tutorial shows you how to train a VITS model This tutorial shows you how to train a VITS model

View File

@ -17,3 +17,4 @@ We may add recipes for other tasks as well in the future.
Streaming-ASR/index Streaming-ASR/index
RNN-LM/index RNN-LM/index
TTS/index TTS/index
Finetune/index

View File

@ -16,8 +16,8 @@ perturb_speed=true
# #
# - $dl_dir/aidatatang_200zh # - $dl_dir/aidatatang_200zh
# You can find "corpus" and "transcript" inside it. # You can find "corpus" and "transcript" inside it.
# You can download it at # You can download it at https://openslr.org/62/
# https://openslr.org/62/ # If you download the data by yourself, DON'T FORGET to extract the *.tar.gz files under corpus.
dl_dir=$PWD/download dl_dir=$PWD/download

View File

@ -19,7 +19,9 @@ The following table lists the differences among them.
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` | | `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data | | `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data| | `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 | | `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
The decoder in `transducer_stateless` is modified from the paper The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -1,4 +1,4 @@
Please visit Please visit
<https://icefall.readthedocs.io/en/latest/recipes/aishell/conformer_ctc.html> <https://k2-fsa.github.io/icefall/recipes/Non-streaming-ASR/aishell/conformer_ctc.html>
for how to run this recipe. for how to run this recipe.

View File

@ -419,7 +419,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log: if enable_log:
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
@ -432,7 +432,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f,
f"{test_set_name}-{key}",
results_char,
enable_log=enable_log,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer
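The hunk above, and several similar ones below, switch the Chinese recipes to
character-level scoring. A minimal sketch of the pattern with made-up
transcripts (`store_transcripts` and `write_error_stats` are from
`icefall.utils`):

```python
# Hypothetical data illustrating the char_level / compute_CER pattern.
from icefall.utils import store_transcripts, write_error_stats

results = [("cut-1", "今天 天气 很 好".split(), "今天 天汽 很 好".split())]
# Re-split each transcript into characters so errors are counted per character.
results_char = [
    (cut_id, list("".join(ref)), list("".join(hyp)))
    for cut_id, ref, hyp in results
]

store_transcripts(filename="recogs-demo.txt", texts=results, char_level=True)
with open("errs-demo.txt", "w") as f:
    cer = write_error_stats(f, "demo", results_char, enable_log=True, compute_CER=True)
```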

View File

@ -431,7 +431,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log: if enable_log:
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
@ -444,7 +444,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f,
f"{test_set_name}-{key}",
results_char,
enable_log=enable_log,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
fi fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 11: Train RNN LM model" log "Stage 12: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \ python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \ --start-epoch 0 \
--world-size 1 \ --world-size 1 \

View File

@ -390,7 +390,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -402,7 +402,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -526,7 +526,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -538,7 +538,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -444,7 +444,7 @@ def save_results(
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
store_transcripts(filename=recog_path, texts=results_char) store_transcripts(filename=recog_path, texts=results_char, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -452,7 +452,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -89,6 +89,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error()
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
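This hunk, and the similar ones below, replace the inline `RuntimeError` with a
shared helper imported from `icefall.err`. A plausible shape of that helper is
sketched here; the actual message in icefall may be longer and more instructive.

```python
# Sketch only: the real icefall/err.py may format the message differently.
def raise_grad_scale_is_too_small_error(cur_grad_scale: float) -> None:
    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
```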

View File

@ -85,6 +85,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
@ -878,9 +879,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

View File

@ -581,7 +581,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -594,7 +594,11 @@ def save_results(
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -78,6 +78,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -871,9 +872,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -250,7 +250,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
default=1, default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram", help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -492,7 +492,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -500,7 +500,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -78,6 +78,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -882,9 +883,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -78,6 +78,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -278,7 +278,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -289,7 +289,13 @@ def save_results(
for res in results: for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) wer = write_error_stats(
f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info("Wrote detailed error stats to {}".format(errs_filename))

View File

@ -327,7 +327,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
@ -338,7 +338,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -372,7 +372,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -384,7 +384,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -376,7 +376,7 @@ def save_results(
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -388,7 +388,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True f,
f"{test_set_name}-{key}",
results_char,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -214,7 +214,7 @@ def get_parser():
"--model-name", "--model-name",
type=str, type=str,
default="large-v2", default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "tiny"], choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
help="""The model name to use. help="""The model name to use.
""", """,
) )
@ -358,7 +358,7 @@ def save_results(
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
) )
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log: if enable_log:
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
@ -373,7 +373,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f,
f"{test_set_name}-{key}",
results_char,
enable_log=enable_log,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -19,7 +19,7 @@
Usage: Usage:
#fine-tuning with deepspeed zero stage 1 #fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \ torchrun --nproc_per_node 8 ./whisper/train.py \
--max-duration 200 \ --max-duration 200 \
--exp-dir whisper/exp_large_v2 \ --exp-dir whisper/exp_large_v2 \
--model-name large-v2 \ --model-name large-v2 \
@ -28,7 +28,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
--deepspeed_config ./whisper/ds_config_zero1.json --deepspeed_config ./whisper/ds_config_zero1.json
# fine-tuning with ddp # fine-tuning with ddp
torchrun --nproc-per-node 8 ./whisper/train.py \ torchrun --nproc_per_node 8 ./whisper/train.py \
--max-duration 200 \ --max-duration 200 \
--exp-dir whisper/exp_medium \ --exp-dir whisper/exp_medium \
--manifest-dir data/fbank_whisper \ --manifest-dir data/fbank_whisper \
@ -136,7 +136,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless7/exp", default="whisper/exp",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
@ -147,7 +147,7 @@ def get_parser():
"--model-name", "--model-name",
type=str, type=str,
default="large-v2", default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "tiny"], choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
help="""The model name to use. help="""The model name to use.
""", """,
) )
@ -793,7 +793,7 @@ def run(rank, world_size, args):
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( opts = diagnostics.TensorDiagnosticOptions(
2**22 512
) # allow 4 megabytes per sub-module ) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts) diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -560,7 +560,7 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
) )
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -570,7 +570,11 @@ def save_results(
) )
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer

View File

@ -86,6 +86,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
@ -985,9 +986,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -83,6 +83,7 @@ from icefall.checkpoint import (
update_averaged_model, update_averaged_model,
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -570,9 +571,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
save_bad_model() save_bad_model()
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr()) cur_lr = max(scheduler.get_last_lr())

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False): def compute_fbank_aishell2(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests") src_dir = Path("data/manifests")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(8, os.cpu_count())
dataset_parts = ( dataset_parts = (
"train", "train",
@ -68,7 +77,11 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
list(manifests.keys()), list(manifests.keys()),
dataset_parts, dataset_parts,
) )
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"], supervisions=m["supervisions"],
) )
if "train" in partition and perturb_speed: if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb") logging.info("Doing speed perturb")
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
) )
@ -111,7 +124,12 @@ def get_args():
default=False, default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
) )
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args() return parser.parse_args()
@ -122,5 +140,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_aishell2( compute_fbank_aishell2(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
) )
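The same `WhisperFbank` switch recurs in the aishell4 and alimeeting scripts
below. A standalone sketch of the added branch (the manifest and storage paths
are hypothetical):

```python
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)

whisper_fbank = True
num_mel_bins = 80  # matches whisper_mel_bins in the prepare.sh stages

if whisper_fbank:
    # Whisper-style log-mel features, computed on GPU.
    extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device="cuda"))
else:
    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

cut_set = CutSet.from_file("data/manifests/aishell2_cuts_train.jsonl.gz")  # hypothetical
cut_set = cut_set.compute_and_store_features(
    extractor=extractor,
    storage_path="data/fbank/aishell2_feats_train",
    num_jobs=8,
    storage_type=LilcomChunkyWriter,
)
```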

View File

@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi fi
fi fi
whisper_mel_bins=80
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
log "Stage 30: Compute whisper fbank for aishell2"
if [ ! -f data/fbank/.aishell2.whisper.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.aishell2.whisper.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan" log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.msuan.done ]; then if [ ! -f data/fbank/.msuan.done ]; then

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False): def compute_fbank_aishell4(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/aishell4") src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(8, os.cpu_count())
dataset_parts = ( dataset_parts = (
"train_S", "train_S",
@ -70,6 +79,11 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
dataset_parts, dataset_parts,
) )
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"], supervisions=m["supervisions"],
) )
if "train" in partition and perturb_speed: if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb") logging.info("Doing speed perturb")
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
) )
@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
# when an executor is specified, make more partitions # when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80, num_jobs=num_jobs if ex is None else 80,
executor=ex, executor=ex,
storage_type=ChunkedLilcomHdf5Writer, storage_type=LilcomChunkyWriter,
) )
logging.info("About splitting cuts into smaller chunks") logging.info("About splitting cuts into smaller chunks")
@ -121,7 +135,12 @@ def get_args():
default=False, default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
) )
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args() return parser.parse_args()
@ -132,5 +151,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_aishell4( compute_fbank_aishell4(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail set -eou pipefail
stage=-1 stage=-1
stop_stage=100 stop_stage=7
perturb_speed=true perturb_speed=true
@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process aishell4" log "Stage 2: Compute fbank for aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then if [ ! -f data/fbank/aishell4/.fbank.done ]; then
mkdir -p data/fbank/aishell4 mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
touch data/fbank/aishell4/.fbank.done touch data/fbank/.fbank.done
fi
fi
whisper_mel_bins=80
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "Stage 20: Compute whisper fbank for aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.fbank.done
fi fi
fi fi
@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for aishell4" log "Stage 5: Prepare char based lang"
if [ ! -f data/fbank/.aishell4.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
touch data/fbank/.aishell4.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
lang_char_dir=data/lang_char lang_char_dir=data/lang_char
mkdir -p $lang_char_dir mkdir -p $lang_char_dir

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False): def compute_fbank_alimeeting(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/alimeeting") src_dir = Path("data/manifests/alimeeting")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(8, os.cpu_count())
dataset_parts = ( dataset_parts = (
"train", "train",
@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
"test", "test",
) )
prefix = "alimeeting" prefix = "alimeeting-far"
suffix = "jsonl.gz" suffix = "jsonl.gz"
manifests = read_manifests_if_cached( manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, dataset_parts=dataset_parts,
@ -70,6 +79,11 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
dataset_parts, dataset_parts,
) )
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once. with get_executor() as ex: # Initialize the executor only once.
@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
supervisions=m["supervisions"], supervisions=m["supervisions"],
) )
if "train" in partition and perturb_speed: if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb") logging.info("Doing speed perturb")
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
) )
@ -121,7 +135,12 @@ def get_args():
default=False, default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
) )
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use the Whisper Fbank feature extractor. Default: False.",
)
return parser.parse_args() return parser.parse_args()
@ -132,5 +151,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_alimeeting( compute_fbank_alimeeting(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail set -eou pipefail
stage=-1 stage=-1
stop_stage=100 stop_stage=7
perturb_speed=true perturb_speed=true
# We assume dl_dir (download dir) contains the following # We assume dl_dir (download dir) contains the following
@ -15,7 +15,7 @@ perturb_speed=true
# #
# - $dl_dir/alimeeting # - $dl_dir/alimeeting
# This directory contains the following files downloaded from # This directory contains the following files downloaded from
# https://openslr.org/62/ # https://openslr.org/119/
# #
# - Train_Ali_far.tar.gz # - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz # - Train_Ali_near.tar.gz
@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process alimeeting" log "Stage 2: compute fbank for alimeeting"
if [ ! -f data/fbank/alimeeting/.fbank.done ]; then if [ ! -f data/fbank/.fbank.done ]; then
mkdir -p data/fbank/alimeeting mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
touch data/fbank/.fbank.done
fi
fi
whisper_mel_bins=80
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
log "Stage 20: compute whisper fbank for alimeeting"
if [ ! -f data/fbank/.fbank.done ]; then
mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.fbank.done
fi fi
fi fi
@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Compute fbank for alimeeting" log "Stage 5: Prepare char based lang"
if [ ! -f data/fbank/.alimeeting.done ]; then
mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed True
touch data/fbank/.alimeeting.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
lang_char_dir=data/lang_char lang_char_dir=data/lang_char
mkdir -p $lang_char_dir mkdir -p $lang_char_dir

View File

@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting
# #
# - $dl_dir/alimeeting # - $dl_dir/alimeeting
# This directory contains the following files downloaded from # This directory contains the following files downloaded from
# https://openslr.org/62/ # https://openslr.org/119/
# #
# - Train_Ali_far.tar.gz # - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz # - Train_Ali_near.tar.gz

View File

@ -70,6 +70,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -851,9 +852,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -69,6 +69,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -842,9 +843,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -75,6 +75,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1138,9 +1139,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

View File

@ -75,6 +75,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -1129,9 +1130,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01: if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}") logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05: if cur_grad_scale < 1.0e-05:
raise RuntimeError( raise_grad_scale_is_too_small_error(cur_grad_scale)
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0: if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0] cur_lr = scheduler.get_last_lr()[0]

12
egs/audioset/AT/README.md Normal file
View File

@ -0,0 +1,12 @@
# Introduction
This is an audio tagging recipe for [Audioset](https://research.google.com/audioset/#/). It aims to predict the sound events in an audio clip.
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Zipformer
| Encoder | Feature type |
| --------| -------------|
| Zipformer | Frame level fbank|

View File

@ -0,0 +1,95 @@
## Results
### zipformer
See <https://github.com/k2-fsa/icefall/pull/1421> for more details
[zipformer](./zipformer)
#### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/>
The model achieves the following mean average precision (mAP) on AudioSet:
| Model | mAP |
| ------ | ------- |
| Zipformer-AT | 45.1 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="4,5,6,7"
subset=full
python zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--exp-dir zipformer/exp_at_as_${subset} \
--start-epoch 1 \
--use-fp16 1 \
--num-events 527 \
--audioset-subset $subset \
--max-duration 1000 \
--enable-musan True \
--master-port 13455
```
The evaluation command is:
```bash
python zipformer/evaluate.py \
--epoch 32 \
--avg 8 \
--exp-dir zipformer/exp_at_as_full \
--max-duration 500
```
#### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-small-2024-04-23#/>
The model achieves the following mean average precision (mAP) on AudioSet:
| Model | mAP |
| ------ | ------- |
| Zipformer-S-AT | 45.1 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="4,5,6,7"
subset=full
python zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--exp-dir zipformer/exp_small_at_as_${subset} \
--start-epoch 1 \
--use-fp16 1 \
--num-events 527 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--audioset-subset $subset \
--max-duration 1200 \
--enable-musan True \
--master-port 13455
```
The evaluation command is:
```bash
python zipformer/evaluate.py \
--epoch 31 \
--avg 4 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--exp-dir zipformer/exp_small_at_as_full \
--max-duration 500
```

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -0,0 +1,177 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file generates the manifest and computes the fbank features for the AudioSet
dataset. The generated manifests and features are stored in data/fbank.
"""
import argparse
import csv
import glob
import logging
import os
from typing import Dict
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.audio import Recording
from lhotse.cut import MonoCut
from lhotse.supervision import SupervisionSegment
from icefall.utils import get_executor
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_ID_mapping(csv_file):
# get a mapping between class ID and class name
mapping = {}
with open(csv_file, "r") as fin:
reader = csv.reader(fin, delimiter=",")
for i, row in enumerate(reader):
if i == 0:
continue
mapping[row[1]] = row[0]
return mapping
def parse_csv(csv_file: str, id_mapping: Dict):
# The content of the csv file should be something like this
# ------------------------------------------------------
# filename label
# dataset/AudioSet/balanced/xxxx.wav 0;451
# dataset/AudioSet/balanced/xxxy.wav 375
# ------------------------------------------------------
def name2id(names):
ids = [id_mapping[name] for name in names.split(",")]
return ";".join(ids)
mapping = {}
with open(csv_file, "r") as fin:
reader = csv.reader(fin, delimiter=" ")
for i, row in enumerate(reader):
if i <= 2:
continue
key = row[0].replace(",", "")
mapping[key] = name2id(row[-1])
return mapping
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--dataset-dir", type=str, default="downloads/audioset")
parser.add_argument(
"--split",
type=str,
default="balanced",
choices=["balanced", "unbalanced", "eval"],
)
parser.add_argument(
"--feat-output-dir",
type=str,
default="data/fbank",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
dataset_dir = args.dataset_dir
split = args.split
feat_output_dir = args.feat_output_dir
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
if split in ["balanced", "unbalanced"]:
csv_file = f"{dataset_dir}/{split}_train_segments.csv"
elif split == "eval":
csv_file = f"{dataset_dir}/eval_segments.csv"
else:
raise ValueError(f"Unsupported split: {split}")
class_indices_csv = f"{dataset_dir}/class_labels_indices.csv"
id_mapping = get_ID_mapping(class_indices_csv)
labels = parse_csv(csv_file, id_mapping)
audio_files = glob.glob(f"{dataset_dir}/{split}/*.wav")
new_cuts = []
for i, audio in enumerate(audio_files):
cut_id = audio.split("/")[-1].split("_")[0]
recording = Recording.from_file(audio, cut_id)
cut = MonoCut(
id=cut_id,
start=0.0,
duration=recording.duration,
channel=0,
recording=recording,
)
supervision = SupervisionSegment(
id=cut_id,
recording_id=cut.recording.id,
start=0.0,
channel=0,
duration=cut.duration,
)
try:
supervision.audio_event = labels[cut_id]
except KeyError:
logging.info(f"No labels found for {cut_id}.")
continue
cut.supervisions = [supervision]
new_cuts.append(cut)
if i % 100 == 0 and i:
logging.info(f"Processed {i} cuts until now.")
cuts = CutSet.from_cuts(new_cuts)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
logging.info(f"Computing fbank features for {split}")
with get_executor() as ex:
cuts = cuts.compute_and_store_features(
extractor=extractor,
storage_path=f"{feat_output_dir}/{split}_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
manifest_output_path = f"{feat_output_dir}/cuts_audioset_{split}.jsonl.gz"
logging.info(f"Storing the manifest to {manifest_output_path}")
cuts.to_jsonl(manifest_output_path)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
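For reference, a minimal sketch of how the two CSV helpers above fit together (the paths follow prepare.sh; the sample rows in the comments are illustrative):
```python
# Sketch only: data flow of get_ID_mapping() and parse_csv().
from generate_audioset_manifest import get_ID_mapping, parse_csv

# class_labels_indices.csv rows look like: 0,/m/09x0r,"Speech",
# so id_mapping maps a label mid ("/m/09x0r") to its index ("0").
id_mapping = get_ID_mapping("downloads/audioset/class_labels_indices.csv")

# parse_csv() converts each segment row into {clip_id: "idx1;idx2;..."},
# which later becomes the supervision.audio_event field of each cut.
labels = parse_csv("downloads/audioset/balanced_train_segments.csv", id_mapping)
print(next(iter(labels.items())))
```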

104
egs/audioset/AT/prepare.sh Executable file
View File

@ -0,0 +1,104 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
# run stage -1 to stage 4 by default
stage=-1
stop_stage=4
dl_dir=$PWD/download
# We assume that you have downloaded AudioSet and placed it
# under $dl_dir/audioset; the folder structure should look like
# this:
# - $dl_dir/audioset
# - balanced
# - eval
# - unbalanced
# If you haven't downloaded the AudioSet, please refer to
# https://github.com/RicherMans/SAT/blob/main/datasets/audioset/1_download_audioset.sh.
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "Running prepare.sh"
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage 0: Download the necessary csv files"
if [ ! -e $dl_dir/audioset/.csv.done]; then
wget --continue "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv" -O "${dl_dir}/audioset/class_labels_indices.csv"
wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv -O "${dl_dir}/audioset/balanced_train_segments.csv"
wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/eval_segments.csv -O "${dl_dir}/audioset/eval_segments.csv"
touch $dl_dir/audioset/.csv.done
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
fbank_dir=data/fbank
if [ ! -e $fbank_dir/.balanced.done ]; then
python local/generate_audioset_manifest.py \
--dataset-dir $dl_dir/audioset \
--split balanced \
--feat-output-dir $fbank_dir
touch $fbank_dir/.balanced.done
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Construct the audioset manifest and compute the fbank features for unbalanced set"
fbank_dir=data/fbank
if [ ! -e $fbank_dir/.unbalanced.done ]; then
python local/generate_audioset_manifest.py \
--dataset-dir $dl_dir/audioset \
--split unbalanced \
--feat-output-dir $fbank_dir
touch $fbank_dir/.unbalanced.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Construct the audioset manifest and compute the fbank features for eval set"
fbank_dir=data/fbank
if [ ! -e $fbank_dir/.eval.done ]; then
python local/generate_audioset_manifest.py \
--dataset-dir $dl_dir/audioset \
--split eval \
--feat-output-dir $fbank_dir
touch $fbank_dir/.eval.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to $dl_dir/musan
mkdir -p data/manifests
if [ ! -e data/manifests/.musan.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
mkdir -p data/fbank
if [ ! -e data/fbank/.musan.done ]; then
./local/compute_fbank_musan.py
touch data/fbank/.musan.done
fi
fi

1
egs/audioset/AT/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -1,7 +1,6 @@
# Copyright 2021 Piotr Żelasko # Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../LICENSE for clarification regarding multiple authors
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import inspect import inspect
import logging import logging
@ -26,12 +24,12 @@ from typing import Any, Dict, Optional
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
AudioTaggingDataset,
CutConcatenate, CutConcatenate,
CutMix, CutMix,
DynamicBucketingSampler, DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@ -52,14 +50,12 @@ class _SeedWorkers:
fix_random_seed(self.seed + worker_id) fix_random_seed(self.seed + worker_id)
class CommonVoiceAsrDataModule: class AudioSetATDatamodule:
""" """
DataModule for k2 ASR experiments. DataModule for k2 audio tagging (AT) experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
It contains all the common data pipeline modules used in AT
experiments, e.g.: experiments, e.g.:
- dynamic batch size, - dynamic batch size,
- bucketing samplers, - bucketing samplers,
@ -67,7 +63,7 @@ class CommonVoiceAsrDataModule:
- augmentation, - augmentation,
- on-the-fly feature extraction - on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks. This class should be derived for specific corpora used in AT tasks.
""" """
def __init__(self, args: argparse.Namespace): def __init__(self, args: argparse.Namespace):
@ -76,7 +72,7 @@ class CommonVoiceAsrDataModule:
@classmethod @classmethod
def add_arguments(cls, parser: argparse.ArgumentParser): def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group( group = parser.add_argument_group(
title="ASR data related options", title="AT data related options",
description="These options are used for the preparation of " description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data " "effective batch sizes, sampling strategies, applied data "
@ -84,22 +80,17 @@ class CommonVoiceAsrDataModule:
) )
group.add_argument( group.add_argument(
"--language", "--audioset-subset",
type=str, type=str,
default="fr", default="balanced",
help="""Language of Common Voice""", choices=["balanced", "full"],
)
group.add_argument(
"--cv-manifest-dir",
type=Path,
default=Path("data/fr/fbank"),
help="Path to directory with CommonVoice train/dev/test cuts.",
) )
group.add_argument( group.add_argument(
"--manifest-dir", "--manifest-dir",
type=Path, type=Path,
default=Path("data/fbank"), default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.", help="Path to directory with audioset train/test cuts.",
) )
group.add_argument( group.add_argument(
"--max-duration", "--max-duration",
@ -218,7 +209,7 @@ class CommonVoiceAsrDataModule:
self, self,
cuts_train: CutSet, cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None, sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader: ):
""" """
Args: Args:
cuts_train: cuts_train:
@ -232,7 +223,7 @@ class CommonVoiceAsrDataModule:
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
@ -278,7 +269,7 @@ class CommonVoiceAsrDataModule:
logging.info("Disable SpecAugment") logging.info("Disable SpecAugment")
logging.info("About to create train dataset") logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset( train = AudioTaggingDataset(
input_strategy=eval(self.args.input_strategy)(), input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms, cut_transforms=transforms,
input_transforms=input_transforms, input_transforms=input_transforms,
@ -296,7 +287,7 @@ class CommonVoiceAsrDataModule:
# to be strict (e.g. could be randomized) # to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset( train = AudioTaggingDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms, input_transforms=input_transforms,
@ -310,16 +301,15 @@ class CommonVoiceAsrDataModule:
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets, num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SimpleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = SimpleCutSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
drop_last=self.args.drop_last,
) )
logging.info("About to create train dataloader") logging.info("About to create train dataloader")
@ -354,13 +344,13 @@ class CommonVoiceAsrDataModule:
logging.info("About to create dev dataset") logging.info("About to create dev dataset")
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset( validate = AudioTaggingDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:
validate = K2SpeechRecognitionDataset( validate = AudioTaggingDataset(
cut_transforms=transforms, cut_transforms=transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@ -382,10 +372,12 @@ class CommonVoiceAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset") logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset( test = AudioTaggingDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(), else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = DynamicBucketingSampler( sampler = DynamicBucketingSampler(
@ -403,22 +395,28 @@ class CommonVoiceAsrDataModule:
return test_dl return test_dl
@lru_cache() @lru_cache()
def train_cuts(self) -> CutSet: def audioset_train_cuts(self) -> CutSet:
logging.info("About to get train cuts") logging.info("About to get the audioset training cuts.")
return load_manifest_lazy( balanced_cuts = load_manifest_lazy(
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz" self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
) )
if self.args.audioset_subset == "full":
unbalanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
)
cuts = CutSet.mux(
balanced_cuts,
unbalanced_cuts,
weights=[20000, 2000000],
stop_early=True,
)
else:
cuts = balanced_cuts
return cuts
@lru_cache() @lru_cache()
def dev_cuts(self) -> CutSet: def audioset_eval_cuts(self) -> CutSet:
logging.info("About to get dev cuts") logging.info("About to get audioset eval cuts")
return load_manifest_lazy( return load_manifest_lazy(
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz" self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz"
) )
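The `full` subset above interleaves the balanced and unbalanced manifests with lhotse's `CutSet.mux`; a standalone sketch of the same pattern (manifest paths as in this recipe):
```python
# Sketch of the muxing pattern used by audioset_train_cuts().
# The weights are roughly the cut counts of the two manifests, so
# sampling approximates one uniform pass over the combined data;
# stop_early=True stops once one of the sources is exhausted.
from lhotse import CutSet, load_manifest_lazy

balanced = load_manifest_lazy("data/fbank/cuts_audioset_balanced.jsonl.gz")
unbalanced = load_manifest_lazy("data/fbank/cuts_audioset_unbalanced.jsonl.gz")
train_cuts = CutSet.mux(
    balanced, unbalanced, weights=[20000, 2000000], stop_early=True
)
```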

View File

@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,327 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0"
./zipformer/evaluate.py \
--epoch 50 \
--avg 10 \
--exp-dir zipformer/exp \
--max-duration 1000
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from at_datamodule import AudioSetATDatamodule
try:
from sklearn.metrics import average_precision_score
except ImportError:
raise ImportError("Please run\npip3 install -U scikit-learn")
from train import add_model_arguments, get_model, get_params, str2multihot
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
add_model_arguments(parser)
return parser
def inference_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
):
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3, feature.shape
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
audio_event = supervisions["audio_event"]
label, _ = str2multihot(audio_event)
label = label.detach().cpu()
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
audio_logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
# convert to probabilities between 0-1
audio_logits = audio_logits.sigmoid().detach().cpu()
return audio_logits, label
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
all_logits = []
all_labels = []
for batch_idx, batch in enumerate(dl):
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
num_cuts += len(cut_ids)
audio_logits, labels = inference_one_batch(
params=params,
model=model,
batch=batch,
)
all_logits.append(audio_logits)
all_labels.append(labels)
if batch_idx % 20 == 1:
logging.info(f"Processed {num_cuts} cuts already.")
logging.info("Finish collecting audio logits")
return all_logits, all_labels
@torch.no_grad()
def main():
parser = get_parser()
AudioSetATDatamodule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "inference_audio_tagging"
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Evaluation started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
args.return_cuts = True
audioset = AudioSetATDatamodule(args)
audioset_cuts = audioset.audioset_eval_cuts()
audioset_dl = audioset.valid_dataloaders(audioset_cuts)
test_sets = ["audioset_eval"]
logits, labels = decode_dataset(
dl=audioset_dl,
params=params,
model=model,
)
logits = torch.cat(logits, dim=0).squeeze(dim=1).detach().numpy()
labels = torch.cat(labels, dim=0).long().detach().numpy()
# compute the metric
mAP = average_precision_score(
y_true=labels,
y_score=logits,
)
logging.info(f"mAP for audioset eval is: {mAP}")
logging.info("Done")
if __name__ == "__main__":
main()
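Once the per-batch logits and multi-hot labels are stacked, the metric above reduces to a single scikit-learn call; a toy example with made-up values:
```python
# Toy illustration of the mAP computation at the end of main().
import numpy as np
from sklearn.metrics import average_precision_score

y_true = np.array([[1, 0, 1], [0, 1, 0]])               # multi-hot labels
y_score = np.array([[0.9, 0.2, 0.6], [0.1, 0.8, 0.3]])  # sigmoid outputs

# The default average="macro" averages the per-class AP over all
# classes (527 for AudioSet); this perfectly ranked toy case gives 1.0.
print(average_precision_score(y_true=y_true, y_score=y_score))
```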

View File

@ -0,0 +1,411 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
Usage of this script:
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
repo=$(basename $repo_url)
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo/exp
git lfs pull --include pretrained.pt
ln -s pretrained.pt epoch-99.pt
popd
python3 zipformer/export-onnx.py \
--exp-dir $repo/exp \
--epoch 99 \
--avg 1 \
--use-averaged-model 0
pushd $repo/exp
mv model-epoch-99-avg-1.onnx model.onnx
mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
popd
See ./onnx_pretrained.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict
import onnx
import onnxoptimizer
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from onnxsim import simplify
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxAudioTagger(nn.Module):
"""A wrapper for Zipformer audio tagger"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, classifier: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_embed:
The convolutional subsampling module placed before the encoder.
classifier:
The linear classification head that produces per-frame logits.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.classifier = classifier
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> torch.Tensor:
"""Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return probs, a 2-D tensor of shape (N, num_classes).
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C)
logits = self.classifier(encoder_out) # (N, T, num_classes)
# Note that this is slightly different from model.py for better
# support of onnx
logits = logits.mean(dim=1)
probs = logits.sigmoid()
return probs
def export_audio_tagging_model_onnx(
model: OnnxAudioTagger,
filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has one output:
- probs, a tensor of shape (N, num_classes); dtype is torch.float32
Args:
model:
The audio tagging model to be exported.
filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 200, 80, dtype=torch.float32)
x_lens = torch.tensor([200], dtype=torch.int64)
model = torch.jit.trace(model, (x, x_lens))
torch.onnx.export(
model,
(x, x_lens),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["logits"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"probs": {0: "N"},
},
)
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "zipformer2 audio tagger",
"url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=filename, meta_data=meta_data)
def optimize_model(filename):
# see
# https://github.com/microsoft/onnxruntime/issues/1899#issuecomment-534806537
# and
# https://github.com/onnx/onnx/issues/582#issuecomment-937788108
# and
# https://github.com/onnx/optimizer/issues/110
# and
# https://qiita.com/Yossy_Hal/items/34f3b2aef2199baf7f5f
passes = ["eliminate_unused_initializer"]
onnx_model = onnx.load(filename)
onnx_model = onnxoptimizer.optimize(onnx_model, passes)
model_simp, check = simplify(onnx_model)
if check:
logging.info("Simplified the model!")
onnx_model = model_simp
else:
logging.info("Failed to simplify the model!")
onnx.save(onnx_model, filename)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
model = OnnxAudioTagger(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
classifier=model.classifier,
)
model_num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"total parameters: {model_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting audio tagging model")
model_filename = params.exp_dir / f"model-{suffix}.onnx"
export_audio_tagging_model_onnx(
model,
model_filename,
opset_version=opset_version,
)
optimize_model(model_filename)
logging.info(f"Exported audio tagging model to {model_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
model_filename_int8 = params.exp_dir / f"model-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
optimize_model(model_filename_int8)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
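A quick sanity check of the exported file, reading back the metadata written by `add_meta_data` and the input/output names (a sketch; the file name matches the rename shown in the usage notes above):
```python
# Sketch: inspect the exported ONNX audio tagger.
import onnx
import onnxruntime as ort

m = onnx.load("model.onnx")
print({p.key: p.value for p in m.metadata_props})  # model_type, url, ...

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])   # ['x', 'x_lens']
print([o.name for o in sess.get_outputs()])  # ['logits'] (sigmoid probabilities)
```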

View File

@ -0,0 +1,340 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is an example for the AudioSet dataset; if you are using a different
dataset, you should change the argument values accordingly.
(1) Export to torchscript model using torch.jit.script()
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
and https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--epoch 30 \
--avg 9
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `zipformer/evaluate.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/audioset/AT
./zipformer/evaluate.py \
--exp-dir ./zipformer/exp \
--use-averaged-model False \
--epoch 9999 \
--avg 1 \
--max-duration 600
Check ./pretrained.py for its usage.
"""
import argparse
import logging
from pathlib import Path
from typing import Tuple
import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
add_model_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
class Classifier(nn.Module):
"""A wrapper for audio tagging classifier"""
def __init__(self, classifier: nn.Module) -> None:
super().__init__()
self.classifier = classifier
def forward(self, encoder_out: Tensor, encoder_out_lens: Tensor):
"""
Args:
encoder_out:
A 3-D tensor of shape (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
"""
logits = self.classifier(encoder_out) # (N, T, num_classes)
padding_mask = make_pad_mask(encoder_out_lens)
logits[padding_mask] = 0  # mask the padding frames
logits = logits.sum(dim=1)  # sum over the time axis
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
logits
)  # average over the non-padding frames
return logits
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
logging.info(f"device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
model.classifier = Classifier(model.classifier)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,191 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
# 2024 Xiaoyu Yang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--epoch 30 \
--avg 9 \
--jit 1
Usage of this script:
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
repo=$(basename $repo_url)
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo/exp
git lfs pull --include jit_script.pt
popd
python3 zipformer/jit_pretrained.py \
--nn-model-filename $repo/exp/jit_script.pt \
--label-dict $repo/data/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav
"""
import argparse
import csv
import logging
import math
from typing import List
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)
parser.add_argument(
"--label-dict",
type=str,
help="""class_labels_indices.csv.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = torch.jit.load(args.nn_model_filename)
model.eval()
model.to(device)
# get the label dictionary
label_dict = {}
with open(args.label_dict, "r") as f:
reader = csv.reader(f, delimiter=",")
for i, row in enumerate(reader):
if i == 0:
continue
label_dict[int(row[0])] = row[2]
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
features=features,
feature_lengths=feature_lengths,
)
logits = model.classifier(encoder_out, encoder_out_lens)
for filename, logit in zip(args.sound_files, logits):
topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
logging.info(
f"{filename}: Top 5 predicted labels are {topk_labels} with "
f"probability of {topk_prob.tolist()}"
)
logging.info("Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,153 @@
# Copyright 2021-2023 Xiaomi Corp. (authors: Xiaoyu Yang,
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import make_pad_mask
class AudioTaggingModel(nn.Module):
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
encoder_dim: int = 384,
num_events: int = 527,
):
"""An audio tagging model
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the encoder network. It accepts
two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `encoder_out` of shape (N, T, encoder_dim) and
`encoder_out_lens` of shape (N,).
encoder_dim:
Dimension of the encoder.
num_events:
The number of classes.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder_embed = encoder_embed
self.encoder = encoder
self.encoder_dim = encoder_dim
self.classifier = nn.Sequential(
nn.Dropout(0.1),
nn.Linear(encoder_dim, num_events),
)
# for multi-class classification
self.criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")
def forward_encoder(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
target:
The ground-truth labels of the audio events; may be multi-hot.
Returns:
Return the binary cross-entropy loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
# Forward the audio tagging module
logits = self.forward_audio_tagging(
encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
) # (N, num_classes)
loss = self.criterion(logits, target)
return loss
def forward_audio_tagging(self, encoder_out, encoder_out_lens):
"""
Args:
encoder_out:
A 3-D tensor of shape (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
A 2-D tensor of shape (N, num_classes).
"""
logits = self.classifier(encoder_out) # (N, T, num_classes)
padding_mask = make_pad_mask(encoder_out_lens)
logits[padding_mask] = 0  # mask the padding frames
logits = logits.sum(dim=1)  # sum over the time axis
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
logits
)  # average over the non-padding frames
return logits
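The `target` passed to `forward` is a multi-hot matrix built from the `;`-joined label indices stored in `supervision.audio_event`; `train.py` provides `str2multihot` for this. A hypothetical re-implementation, for illustration only:
```python
# Hypothetical sketch of building multi-hot targets from the
# ";"-joined label-index strings (train.py's actual str2multihot
# may differ in details).
from typing import List, Tuple
import torch

def str2multihot(events: List[str], n_classes: int = 527) -> Tuple[torch.Tensor, List[List[int]]]:
    ids = [[int(i) for i in e.split(";")] for e in events]
    target = torch.zeros(len(events), n_classes)
    for row, class_ids in enumerate(ids):
        target[row, class_ids] = 1.0
    return target, ids

target, _ = str2multihot(["0;137", "65"])
print(target.sum(dim=-1))  # tensor([2., 1.])
```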

View File

@ -0,0 +1,227 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
Usage of this script:
repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09
repo=$(basename $repo_url)
git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
popd
for m in model.onnx model.int8.onnx; do
python3 zipformer/onnx_pretrained.py \
--model-filename $repo/$m \
--label-dict $repo/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav
done
"""
import argparse
import csv
import logging
import math
from typing import List
import kaldifeat
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the onnx model. ",
)
parser.add_argument(
"--label-dict",
type=str,
help="""class_labels_indices.csv.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
return parser
class OnnxModel:
def __init__(
self,
nn_model: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.init_model(nn_model)
def init_model(self, nn_model: str):
self.model = ort.InferenceSession(
nn_model,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
meta = self.model.get_modelmeta().custom_metadata_map
print(meta)
def __call__(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a Tensor:
- probs, its shape is (N, num_classes)
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0])
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
nn_model=args.model_filename,
)
# get the label dictionary
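# class_labels_indices.csv has a header row followed by lines of the form
# index,mid,display_name (e.g., 0,/m/09x0r,"Speech"); we map index -> display_name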
label_dict = {}
with open(args.label_dict, "r") as f:
reader = csv.reader(f, delimiter=",")
for i, row in enumerate(reader):
if i == 0:
continue
label_dict[int(row[0])] = row[2]
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
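# In kaldifeat, a non-positive high_freq is an offset from the Nyquist
# frequency, so -400 means 8000 - 400 = 7600 Hz at a 16 kHz sample rate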
opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
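# pad in the log-mel domain with log(1e-10), which corresponds to (near) silence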
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
probs = model(features, feature_lengths)
for filename, prob in zip(args.sound_files, probs):
topk_prob, topk_index = prob.topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
logging.info(
f"{filename}: Top 5 predicted labels are {topk_labels} with "
f"probability of {topk_prob.tolist()}"
)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1,202 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with zipformer/train.py, or download a
pretrained one as shown in the usage below.
Note: This is an example for the AudioSet dataset. If you are using a
different dataset, change the argument values according to your dataset.
Usage of this script:
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
repo=$(basename $repo_url)
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo/exp
git lfs pull --include pretrained.pt
popd
python3 zipformer/pretrained.py \
--checkpoint $repo/exp/pretrained.pt \
--label-dict $repo/data/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav
"""
import argparse
import csv
import logging
import math
from typing import List
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--label-dict",
type=str,
help="""class_labels_indices.csv.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
# get the label dictionary
label_dict = {}
with open(params.label_dict, "r") as f:
reader = csv.reader(f, delimiter=",")
for i, row in enumerate(reader):
if i == 0:
continue
label_dict[int(row[0])] = row[2]
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
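# -400 is an offset from the Nyquist frequency, i.e., 7600 Hz at 16 kHz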
opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
# model forward and predict the audio events
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
for filename, logit in zip(params.sound_files, logits):
topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
logging.info(
f"{filename}: Top 5 predicted labels are {topk_labels} with "
f"probability of {topk_prob.tolist()}"
)
logging.info("Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -1,5 +1,74 @@
## Results
### Commonvoice Cantonese (zh-HK) Char training results (Zipformer)
See #1546 for more details.
Number of model parameters: 72526519, i.e., 72.53 M
The best CER for CommonVoice 16.1 (cv-corpus-16.1-2023-12-06/zh-HK) is below:
| | Dev | Test | Note |
|----------------------|-------|------|--------------------|
| greedy_search | 1.17 | 1.22 | --epoch 24 --avg 5 |
| modified_beam_search | 0.98 | 1.11 | --epoch 24 --avg 5 |
| fast_beam_search | 1.08 | 1.27 | --epoch 24 --avg 5 |
When doing cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (without blank penalty),
the best CER is below:
| | Dev | Test | Note |
|----------------------|-------|------|--------------------|
| greedy_search | 42.40 | 42.03| --epoch 24 --avg 5 |
| modified_beam_search | 39.73 | 39.19| --epoch 24 --avg 5 |
| fast_beam_search | 42.14 | 41.98| --epoch 24 --avg 5 |
When doing cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (with the blank penalty set to 2.2),
the best CER is below (see the sketch after the decoding commands for how to pass `--blank-penalty`):
| | Dev | Test | Note |
|----------------------|-------|------|----------------------------------------|
| greedy_search | 39.19 | 39.09| --epoch 24 --avg 5 --blank-penalty 2.2 |
| modified_beam_search | 37.73 | 37.65| --epoch 24 --avg 5 --blank-penalty 2.2 |
| fast_beam_search | 37.73 | 37.74| --epoch 24 --avg 5 --blank-penalty 2.2 |
To reproduce the above result, use the following commands for training:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
./zipformer/train_char.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--cv-manifest-dir data/zh-HK/fbank \
--language zh-HK \
--use-validated-set 1 \
--context-size 1 \
--max-duration 1000
```
and the following commands for decoding:
```bash
for method in greedy_search modified_beam_search fast_beam_search; do
./zipformer/decode_char.py \
--epoch 24 \
--avg 5 \
--decoding-method $method \
--exp-dir zipformer/exp \
--cv-manifest-dir data/zh-HK/fbank \
--context-size 1 \
--language zh-HK
done
```
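To reproduce the MDCC cross-corpus numbers with a blank penalty, the same decoding loop can be reused with the `--blank-penalty` flag from the tables above. A minimal sketch, assuming the MDCC fbank manifests have been prepared and that the manifest-related flags are adjusted to point at them:
```bash
for method in greedy_search modified_beam_search fast_beam_search; do
  ./zipformer/decode_char.py \
    --epoch 24 \
    --avg 5 \
    --decoding-method $method \
    --exp-dir zipformer/exp \
    --context-size 1 \
    --blank-penalty 2.2
done
```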
Detailed experimental results and pre-trained model are available at:
<https://huggingface.co/zrjin/icefall-asr-commonvoice-zh-HK-zipformer-2024-03-20>
### CommonVoice English (en) BPE training results (Pruned Stateless Transducer 7)
#### [pruned_transducer_stateless7](./pruned_transducer_stateless7)
@ -7,14 +76,16 @@ See #997 for more details.
Number of model parameters: 70369391, i.e., 70.37 M
Note that the result is obtained using a BPE model trained on GigaSpeech transcripts.
The best WER, as of 2023-04-17, for Common Voice English 13.0 (cv-corpus-13.0-2023-03-09/en) is below:
Results are:
| | Dev | Test |
|----------------------|-------|-------|
| greedy_search | 9.96 | 12.54 |
| modified_beam_search | 9.86 | 12.48 |
To reproduce the above result, use the following commands for training:
@ -55,10 +126,6 @@ and the following commands for decoding:
Pretrained model is available at
<https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17>
### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming)
#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
@ -73,9 +140,6 @@ Results are:
| decoding method | Test |
|----------------------|-------|
| greedy_search | 9.95 |
| modified_beam_search | 9.57 |
| fast_beam_search | 9.67 |
Note: This best result was obtained by training on the full LibriSpeech and GigaSpeech, then fine-tuning on the full CommonVoice.

View File

@ -1 +0,0 @@
../../../librispeech/ASR/local/compile_hlg.py

View File

@ -0,0 +1,168 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from lang_dir/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
lm:
The stem name of the LM, e.g., G_3_gram.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
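# Disambiguation symbols occupy a contiguous range at the end of the token
# and word tables, with #0 first, so any label >= these ids is a
# disambiguation symbol that must be removed from the final graph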
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

Some files were not shown because too many files have changed in this diff