Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 10:16:14 +00:00

Merge remote-tracking branch 'dan/master' into doc-force-alignment-kaldi
commit cb21b878c0

.github/scripts/.gitignore (vendored, new file, +1)
@@ -0,0 +1 @@
piper_phonemize.html

.github/scripts/audioset/AT/run.sh (vendored, new executable file, +94)
@@ -0,0 +1,94 @@
#!/usr/bin/env bash

set -ex

python3 -m pip install onnxoptimizer onnxsim

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/audioset/AT

function test_pretrained() {
  repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
  repo=$(basename $repo_url)
  GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  pushd $repo/exp
  git lfs pull --include pretrained.pt
  ln -s pretrained.pt epoch-99.pt
  ls -lh
  popd

  log "test pretrained.pt"

  python3 zipformer/pretrained.py \
    --checkpoint $repo/exp/pretrained.pt \
    --label-dict $repo/data/class_labels_indices.csv \
    $repo/test_wavs/1.wav \
    $repo/test_wavs/2.wav \
    $repo/test_wavs/3.wav \
    $repo/test_wavs/4.wav

  log "test jit export"
  ls -lh $repo/exp/
  python3 zipformer/export.py \
    --exp-dir $repo/exp \
    --epoch 99 \
    --avg 1 \
    --use-averaged-model 0 \
    --jit 1
  ls -lh $repo/exp/

  log "test jit models"
  python3 zipformer/jit_pretrained.py \
    --nn-model-filename $repo/exp/jit_script.pt \
    --label-dict $repo/data/class_labels_indices.csv \
    $repo/test_wavs/1.wav \
    $repo/test_wavs/2.wav \
    $repo/test_wavs/3.wav \
    $repo/test_wavs/4.wav

  log "test onnx export"
  ls -lh $repo/exp/
  python3 zipformer/export-onnx.py \
    --exp-dir $repo/exp \
    --epoch 99 \
    --avg 1 \
    --use-averaged-model 0

  ls -lh $repo/exp/

  pushd $repo/exp/
  mv model-epoch-99-avg-1.onnx model.onnx
  mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
  popd

  ls -lh $repo/exp/

  log "test onnx models"
  for m in model.onnx model.int8.onnx; do
    log "$m"
    # Use the loop variable so that the int8 model is actually tested too.
    python3 zipformer/onnx_pretrained.py \
      --model-filename $repo/exp/$m \
      --label-dict $repo/data/class_labels_indices.csv \
      $repo/test_wavs/1.wav \
      $repo/test_wavs/2.wav \
      $repo/test_wavs/3.wav \
      $repo/test_wavs/4.wav
  done

  log "prepare data for uploading to huggingface"
  dst=/icefall/model-onnx
  mkdir -p $dst
  cp -v $repo/exp/*.onnx $dst/
  cp -v $repo/data/* $dst/
  cp -av $repo/test_wavs $dst

  ls -lh $dst
  ls -lh $dst/test_wavs
}

test_pretrained

.github/scripts/docker/Dockerfile (vendored, +8)
@@ -11,6 +11,7 @@ ARG _KALDIFEAT_VERSION="${KALDIFEAT_VERSION}+cpu.torch${TORCH_VERSION}"

 RUN apt-get update -y && \
     apt-get install -qq -y \
+        cmake \
         ffmpeg \
         git \
         git-lfs \
@@ -35,7 +36,9 @@ RUN pip install --no-cache-dir \
         \
         git+https://github.com/lhotse-speech/lhotse \
         kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \
+        cython \
         dill \
+        espnet_tts_frontend \
         graphviz \
         kaldi-decoder \
         kaldi_native_io \
@@ -44,10 +47,15 @@ RUN pip install --no-cache-dir \
         kaldilm \
         matplotlib \
         multi_quantization \
+        numba \
         numpy \
+        onnxoptimizer \
+        onnxsim \
         onnx \
         onnxmltools \
         onnxruntime \
+        piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \
+        pypinyin==0.50.0 \
         pytest \
         sentencepiece>=0.1.96 \
         six \

.github/scripts/docker/generate_build_matrix.py (vendored, 44 changed lines)
@@ -6,8 +6,8 @@ import json


 def version_gt(a, b):
-    a_major, a_minor = a.split(".")[:2]
-    b_major, b_minor = b.split(".")[:2]
+    a_major, a_minor = list(map(int, a.split(".")))[:2]
+    b_major, b_minor = list(map(int, b.split(".")))[:2]
     if a_major > b_major:
         return True

@@ -18,8 +18,8 @@ def version_gt(a, b):


 def version_ge(a, b):
-    a_major, a_minor = a.split(".")[:2]
-    b_major, b_minor = b.split(".")[:2]
+    a_major, a_minor = list(map(int, a.split(".")))[:2]
+    b_major, b_minor = list(map(int, b.split(".")))[:2]
     if a_major > b_major:
         return True

@@ -43,11 +43,16 @@ def get_torchaudio_version(torch_version):


 def get_matrix():
-    k2_version = "1.24.4.dev20231220"
-    kaldifeat_version = "1.25.3.dev20231221"
-    version = "1.2"
-    python_version = ["3.8", "3.9", "3.10", "3.11"]
-    torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+    k2_version = "1.24.4.dev20240223"
+    kaldifeat_version = "1.25.4.dev20240223"
+    version = "20240606"
+    python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
+    torch_version = []
+    torch_version += ["1.13.0", "1.13.1"]
+    torch_version += ["2.0.0", "2.0.1"]
+    torch_version += ["2.1.0", "2.1.1", "2.1.2"]
+    torch_version += ["2.2.0", "2.2.1", "2.2.2"]
+    torch_version += ["2.3.0", "2.3.1"]

     matrix = []
     for p in python_version:
@@ -57,10 +62,27 @@ def get_matrix():
             if version_gt(p, "3.10") and not version_gt(t, "2.0"):
                 continue

+            # only torch>=2.2.0 supports python 3.12
+            if version_gt(p, "3.11") and not version_gt(t, "2.1"):
+                continue
+
+            k2_version_2 = k2_version
+            kaldifeat_version_2 = kaldifeat_version
+
+            if t == "2.2.2":
+                k2_version_2 = "1.24.4.dev20240328"
+                kaldifeat_version_2 = "1.25.4.dev20240329"
+            elif t == "2.3.0":
+                k2_version_2 = "1.24.4.dev20240425"
+                kaldifeat_version_2 = "1.25.4.dev20240425"
+            elif t == "2.3.1":
+                k2_version_2 = "1.24.4.dev20240606"
+                kaldifeat_version_2 = "1.25.4.dev20240606"
+
             matrix.append(
                 {
-                    "k2-version": k2_version,
-                    "kaldifeat-version": kaldifeat_version,
+                    "k2-version": k2_version_2,
+                    "kaldifeat-version": kaldifeat_version_2,
                     "version": version,
                     "python-version": p,
                     "torch-version": t,

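The edit to version_gt and version_ge above fixes a genuine bug rather than just churning versions: the old code compared version components as strings, and lexicographic comparison mis-orders multi-digit components. A standalone sketch of the failure mode (simplified to tuple comparison, which behaves the same as the if-chains in the script):

def version_gt_old(a, b):
    # Old behavior: components stay strings, comparison is lexicographic.
    return a.split(".")[:2] > b.split(".")[:2]

def version_gt_new(a, b):
    # New behavior from this commit: compare components as integers.
    return list(map(int, a.split(".")))[:2] > list(map(int, b.split(".")))[:2]

# As strings, "8" and "9" both sort after "10", so the old code claims
# Python 3.8 and 3.9 are newer than 3.10:
assert version_gt_old("3.9", "3.10")        # wrong answer, silently accepted
assert not version_gt_new("3.9", "3.10")    # correct
assert version_gt_new("3.12", "3.11")       # correct

Without the fix, guards such as version_gt(p, "3.10") would wrongly fire for Python 3.8 and 3.9 and drop valid matrix entries.
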
.github/scripts/generate-piper-phonemize-page.py (vendored, new executable file, +29)
@@ -0,0 +1,29 @@
#!/usr/bin/env python3


def main():
    prefix = (
        "https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
    )
    files = [
        "piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
        "piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
        "piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
    ]
    with open("piper_phonemize.html", "w") as f:
        for file in files:
            url = prefix + file
            f.write(f'<a href="{url}">{file}</a><br/>\n')


if __name__ == "__main__":
    main()

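The page this script writes is nothing more than a flat list of <a> links to wheel files, which is the shape pip accepts for --find-links: pip scans the anchors and matches the wheel filenames against the current interpreter and platform. Once the page is published with the docs (see the build-doc.yml change below), it is consumed as python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, which is exactly the command the new ljspeech CI script runs. A small sketch that reads the wheel URLs back out of the generated page (illustration only, loosely mimicking pip's scan):

import re

# Extract the wheel URLs from the page generated by the script above.
with open("piper_phonemize.html") as f:
    wheels = re.findall(r'<a href="([^"]+\.whl)">', f.read())

for url in wheels:
    print(url)
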
.github/scripts/librispeech/ASR/run.sh (vendored, 47 changed lines)
@@ -15,9 +15,9 @@ function prepare_data() {
   # cause OOM error for CI later.
   mkdir -p download/lm
   pushd download/lm
-  wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
-  wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
-  wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+  wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
+  wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
+  wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt
   ls -lh
   gunzip librispeech-lm-norm.txt.gz

@@ -64,6 +64,46 @@ function run_diagnostics() {
     --print-diagnostics 1
 }

+function test_streaming_zipformer_ctc_hlg() {
+  repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
+
+  log "Downloading pre-trained model from $repo_url"
+  git lfs install
+  git clone $repo_url
+  repo=$(basename $repo_url)
+
+  rm $repo/exp-ctc-rnnt-small/*.onnx
+  ls -lh $repo/exp-ctc-rnnt-small
+
+  # export models to onnx
+  ./zipformer/export-onnx-streaming-ctc.py \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
+    --epoch 30 \
+    --avg 3 \
+    --exp-dir $repo/exp-ctc-rnnt-small \
+    --causal 1 \
+    --use-ctc 1 \
+    --chunk-size 16 \
+    --left-context-frames 128 \
+    \
+    --num-encoder-layers 2,2,2,2,2,2 \
+    --feedforward-dim 512,768,768,768,768,768 \
+    --encoder-dim 192,256,256,256,256,256 \
+    --encoder-unmasked-dim 192,192,192,192,192,192
+
+  ls -lh $repo/exp-ctc-rnnt-small
+
+  for wav in 0.wav 1.wav 8k.wav; do
+    python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
+      --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
+      --words $repo/data/lang_bpe_500/words.txt \
+      --HLG $repo/data/lang_bpe_500/HLG.fst \
+      $repo/test_wavs/$wav
+  done
+
+  rm -rf $repo
+}
+
 function test_pruned_transducer_stateless_2022_03_12() {
   repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12

@@ -1577,6 +1617,7 @@ function test_transducer_bpe_500_2021_12_23() {

 prepare_data
 run_diagnostics
+test_streaming_zipformer_ctc_hlg
 test_pruned_transducer_stateless_2022_03_12
 test_pruned_transducer_stateless2_2022_04_29
 test_pruned_transducer_stateless3_2022_04_29

.github/scripts/ljspeech/TTS/run.sh (vendored, new executable file, +157)
@@ -0,0 +1,157 @@
#!/usr/bin/env bash

set -ex

python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
python3 -m pip install espnet_tts_frontend
python3 -m pip install numba

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/ljspeech/TTS

sed -i.bak s/600/8/g ./prepare.sh
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
sed -i.bak s/500/5/g ./prepare.sh
git diff

function prepare_data() {
  # We have created a subset of the data for testing
  #
  mkdir download
  pushd download
  wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
  tar xvf LJSpeech-1.1.tar.bz2
  popd

  ./prepare.sh
  tree .
}

function train() {
  pushd ./vits
  sed -i.bak s/200/3/g ./train.py
  git diff .
  popd

  for t in low medium high; do
    ./vits/train.py \
      --exp-dir vits/exp-$t \
      --model-type $t \
      --num-epochs 1 \
      --save-every-n 1 \
      --num-buckets 2 \
      --tokens data/tokens.txt \
      --max-duration 20

    ls -lh vits/exp-$t
  done
}

function infer() {
  for t in low medium high; do
    ./vits/infer.py \
      --num-buckets 2 \
      --model-type $t \
      --epoch 1 \
      --exp-dir ./vits/exp-$t \
      --tokens data/tokens.txt \
      --max-duration 20
  done
}

function export_onnx() {
  for t in low medium high; do
    ./vits/export-onnx.py \
      --model-type $t \
      --epoch 1 \
      --exp-dir ./vits/exp-$t \
      --tokens data/tokens.txt

    ls -lh vits/exp-$t/
  done
}

function test_medium() {
  git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12

  ./vits/export-onnx.py \
    --model-type medium \
    --epoch 820 \
    --exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
    --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt

  ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp

  ./vits/test_onnx.py \
    --model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
    --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
    --output-filename /icefall/test-medium.wav

  ls -lh /icefall/test-medium.wav

  d=/icefall/vits-icefall-en_US-ljspeech-medium
  mkdir $d
  cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
  cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx

  rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12

  pushd $d
  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
  tar xf espeak-ng-data.tar.bz2
  rm espeak-ng-data.tar.bz2
  cd ..
  tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
  rm -rf vits-icefall-en_US-ljspeech-medium
  ls -lh *.tar.bz2
  popd
}

function test_low() {
  git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12

  ./vits/export-onnx.py \
    --model-type low \
    --epoch 1600 \
    --exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
    --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt

  ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp

  ./vits/test_onnx.py \
    --model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
    --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
    --output-filename /icefall/test-low.wav

  ls -lh /icefall/test-low.wav

  d=/icefall/vits-icefall-en_US-ljspeech-low
  mkdir $d
  cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
  cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx

  rm -rf icefall-tts-ljspeech-vits-low-2024-03-12

  pushd $d
  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
  tar xf espeak-ng-data.tar.bz2
  rm espeak-ng-data.tar.bz2
  cd ..
  tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
  rm -rf vits-icefall-en_US-ljspeech-low
  ls -lh *.tar.bz2
  popd
}

prepare_data
train
infer
export_onnx
rm -rf vits/exp-{low,medium,high}
test_medium
test_low

.github/workflows/audioset.yml (vendored, new file, +137)
@@ -0,0 +1,137 @@
name: audioset

on:
  push:
    branches:
      - master

  pull_request:
    branches:
      - master

  workflow_dispatch:

concurrency:
  group: audioset-${{ github.ref }}
  cancel-in-progress: true

jobs:
  generate_build_matrix:
    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
    # see https://github.com/pytorch/pytorch/pull/50633
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Generating build matrix
        id: set-matrix
        run: |
          # outputting for debugging purposes
          python ./.github/scripts/docker/generate_build_matrix.py
          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
          echo "::set-output name=matrix::${MATRIX}"

  audioset:
    needs: generate_build_matrix
    name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Free space
        shell: bash
        run: |
          ls -lh
          df -h
          rm -rf /opt/hostedtoolcache
          df -h
          echo "pwd: $PWD"
          echo "github.workspace ${{ github.workspace }}"

      - name: Run tests
        uses: addnab/docker-run-action@v3
        with:
          image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
          options: |
            --volume ${{ github.workspace }}/:/icefall
          shell: bash
          run: |
            export PYTHONPATH=/icefall:$PYTHONPATH
            cd /icefall
            git config --global --add safe.directory /icefall

            .github/scripts/audioset/AT/run.sh

      - name: Show model files
        shell: bash
        run: |
          sudo chown -R runner ./model-onnx
          ls -lh ./model-onnx
          chmod -x ./model-onnx/class_labels_indices.csv

          echo "----------"
          ls -lh ./model-onnx/*

      - name: Upload model to huggingface
        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1

            git clone https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 huggingface
            cd huggingface
            git fetch
            git pull
            git merge -m "merge remote" --ff origin main
            cp ../model-onnx/*.onnx ./
            cp ../model-onnx/*.csv ./
            cp -a ../model-onnx/test_wavs ./
            ls -lh
            git add .
            git status
            git commit -m "update models"
            git status

            git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 main || true
            rm -rf huggingface

      - name: Prepare for release
        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
        shell: bash
        run: |
          d=sherpa-onnx-zipformer-audio-tagging-2024-04-09
          mv ./model-onnx $d
          tar cjvf ${d}.tar.bz2 $d
          ls -lh

      - name: Release exported onnx models
        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          overwrite: true
          file: sherpa-onnx-*.tar.bz2
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: audio-tagging-models

.github/workflows/build-doc.yml (vendored, +3)
@@ -56,11 +56,14 @@ jobs:
       - name: Build doc
         shell: bash
         run: |
+          .github/scripts/generate-piper-phonemize-page.py
           cd docs
           python3 -m pip install -r ./requirements.txt
           make html
           touch build/html/.nojekyll
+
+          cp -v ../piper_phonemize.html ./build/html/

       - name: Deploy
         uses: peaceiris/actions-gh-pages@v3
         with:

.github/workflows/build-docker-image.yml (vendored, 2 changed lines)
@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+        image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]

     steps:
       # refer to https://github.com/actions/checkout

.github/workflows/ljspeech.yml (vendored, new file, +102)
@@ -0,0 +1,102 @@
name: ljspeech

on:
  push:
    branches:
      - master

  pull_request:
    branches:
      - master

  workflow_dispatch:

concurrency:
  group: ljspeech-${{ github.ref }}
  cancel-in-progress: true

jobs:
  generate_build_matrix:
    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
    # see https://github.com/pytorch/pytorch/pull/50633
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Generating build matrix
        id: set-matrix
        run: |
          # outputting for debugging purposes
          python ./.github/scripts/docker/generate_build_matrix.py
          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
          echo "::set-output name=matrix::${MATRIX}"

  ljspeech:
    needs: generate_build_matrix
    name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Free space
        shell: bash
        run: |
          ls -lh
          df -h
          rm -rf /opt/hostedtoolcache
          df -h
          echo "pwd: $PWD"
          echo "github.workspace ${{ github.workspace }}"

      - name: Run tests
        uses: addnab/docker-run-action@v3
        with:
          image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
          options: |
            --volume ${{ github.workspace }}/:/icefall
          shell: bash
          run: |
            export PYTHONPATH=/icefall:$PYTHONPATH
            cd /icefall
            git config --global --add safe.directory /icefall

            .github/scripts/ljspeech/TTS/run.sh

      - name: display files
        shell: bash
        run: |
          ls -lh

      - uses: actions/upload-artifact@v4
        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
        with:
          name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
          path: ./*.wav

      - uses: actions/upload-artifact@v4
        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
        with:
          name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
          path: ./*.wav

      - name: Release exported onnx models
        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          overwrite: true
          file: vits-icefall-*.tar.bz2
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: tts-models

.github/workflows/run-docker-image.yml (vendored, 9 changed lines)
@@ -14,13 +14,20 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+        image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
     steps:
       # refer to https://github.com/actions/checkout
       - uses: actions/checkout@v2
         with:
           fetch-depth: 0

+      - name: Free space
+        shell: bash
+        run: |
+          df -h
+          rm -rf /opt/hostedtoolcache
+          df -h
+
       - name: Run the build process with Docker
         uses: addnab/docker-run-action@v3
         with:

.github/workflows/style_check.yml (vendored, 8 changed lines)
@@ -49,7 +49,7 @@ jobs:

       - name: Install Python dependencies
         run: |
-          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
+          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
          # Click issue fixed in https://github.com/psf/black/pull/2966

       - name: Run flake8
@@ -67,3 +67,9 @@ jobs:
         working-directory: ${{github.workspace}}
         run: |
           black --check --diff .
+
+      - name: Run isort
+        shell: bash
+        working-directory: ${{github.workspace}}
+        run: |
+          isort --check --diff .

.github/workflows/yesno.yml (vendored, +3)
@@ -59,4 +59,7 @@ jobs:
           cd /icefall
           git config --global --add safe.directory /icefall

+          python3 -m torch.utils.collect_env
+          python3 -m k2.version
+
           .github/scripts/yesno/ASR/run.sh

.pre-commit-config.yaml
@@ -26,7 +26,7 @@ repos:
   # E121,E123,E126,E226,E24,E704,W503,W504

   - repo: https://github.com/pycqa/isort
-    rev: 5.11.5
+    rev: 5.12.0
     hooks:
       - id: isort
         args: ["--profile=black"]

README.md
@@ -74,6 +74,9 @@ The [LibriSpeech][librispeech] recipe supports the most comprehensive set of mod
 - LSTM-based Predictor
 - [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/)

+#### Whisper
+
+- [OpenAI Whisper](https://arxiv.org/abs/2212.04356) (We support fine-tuning on AiShell-1.)

 If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details.

 We would like to highlight the performance of some of the recipes here.

docker/torch1.12.1-cuda11.3.dockerfile
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.7
-ARG K2_VERSION="1.24.4.dev20230725+cuda11.3.torch1.12.1"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.3.torch1.12.1"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.3.torch1.12.1"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.3.torch1.12.1"
 ARG TORCHAUDIO_VERSION="0.12.1+cu113"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +55,8 @@ RUN pip install --no-cache-dir \
     onnx \
     onnxruntime \
     onnxmltools \
+    onnxoptimizer \
+    onnxsim \
     multi_quantization \
     typeguard \
     numpy \

docker/torch1.13.0-cuda11.6.dockerfile
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.9
-ARG K2_VERSION="1.24.4.dev20231021+cuda11.6.torch1.13.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.6.torch1.13.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.6.torch1.13.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.6.torch1.13.0"
 ARG TORCHAUDIO_VERSION="0.13.0+cu116"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +55,8 @@ RUN pip install --no-cache-dir \
     onnx \
     onnxruntime \
     onnxmltools \
+    onnxoptimizer \
+    onnxsim \
     multi_quantization \
     typeguard \
     numpy \

docker/torch1.9.0-cuda10.2.dockerfile
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.7
-ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda10.2.torch1.9.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda10.2.torch1.9.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda10.2.torch1.9.0"
 ARG TORCHAUDIO_VERSION="0.9.0"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -69,6 +69,8 @@ RUN pip uninstall -y tqdm && \
     onnx \
     onnxruntime \
     onnxmltools \
+    onnxoptimizer \
+    onnxsim \
     multi_quantization \
     typeguard \
     numpy \

docker/torch2.0.0-cuda11.7.dockerfile
@@ -1,12 +1,13 @@
 FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
+# python 3.10

 ENV LC_ALL C.UTF-8

 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20231021+cuda11.7.torch2.0.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.7.torch2.0.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.7.torch2.0.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.7.torch2.0.0"
 ARG TORCHAUDIO_VERSION="2.0.0+cu117"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +56,8 @@ RUN pip install --no-cache-dir \
     onnx \
     onnxruntime \
     onnxmltools \
+    onnxoptimizer \
+    onnxsim \
     multi_quantization \
     typeguard \
     numpy \

docker/torch2.1.0-cuda11.8.dockerfile
@@ -1,12 +1,13 @@
 FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
+# python 3.10

 ENV LC_ALL C.UTF-8

 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20231021+cuda11.8.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.8.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.1.0"
 ARG TORCHAUDIO_VERSION="2.1.0+cu118"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +56,8 @@ RUN pip install --no-cache-dir \
     onnx \
     onnxruntime \
     onnxmltools \
+    onnxoptimizer \
+    onnxsim \
     multi_quantization \
     typeguard \
     numpy \

docker/torch2.1.0-cuda12.1.dockerfile
@@ -1,12 +1,13 @@
 FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel
+# python 3.10

 ENV LC_ALL C.UTF-8

 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20231021+cuda12.1.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda12.1.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.1.0"
 ARG TORCHAUDIO_VERSION="2.1.0+cu121"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
@@ -55,6 +56,8 @@ RUN pip install --no-cache-dir \
     onnx \
     onnxruntime \
     onnxmltools \
+    onnxoptimizer \
+    onnxsim \
     multi_quantization \
     typeguard \
     numpy \

73
docker/torch2.2.0-cuda11.8.dockerfile
Normal file
73
docker/torch2.2.0-cuda11.8.dockerfile
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-devel
|
||||||
|
# python 3.10
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.0"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.0"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.2.0+cu118"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
onnxoptimizer \
|
||||||
|
onnxsim \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
73
docker/torch2.2.0-cuda12.1.dockerfile
Normal file
73
docker/torch2.2.0-cuda12.1.dockerfile
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel
|
||||||
|
# python 3.10
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.0"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.0"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.2.0+cu121"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
onnxoptimizer \
|
||||||
|
onnxsim \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
73
docker/torch2.2.1-cuda11.8.dockerfile
Normal file
73
docker/torch2.2.1-cuda11.8.dockerfile
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel
|
||||||
|
# python 3.10
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.1"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.1"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.2.1+cu118"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
onnxoptimizer \
|
||||||
|
onnxsim \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
73
docker/torch2.2.1-cuda12.1.dockerfile
Normal file
73
docker/torch2.2.1-cuda12.1.dockerfile
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel
|
||||||
|
# python 3.10
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.1"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.1"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.2.1+cu121"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
onnxoptimizer \
|
||||||
|
onnxsim \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
73
docker/torch2.2.2-cuda11.8.dockerfile
Normal file
73
docker/torch2.2.2-cuda11.8.dockerfile
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-devel
|
||||||
|
# python 3.10
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20240328+cuda11.8.torch2.2.2"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240329+cuda11.8.torch2.2.2"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.2.2+cu118"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
onnxoptimizer \
|
||||||
|
onnxsim \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
73
docker/torch2.2.2-cuda12.1.dockerfile
Normal file
73
docker/torch2.2.2-cuda12.1.dockerfile
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
FROM pytorch/pytorch:2.2.2-cuda12.1-cudnn8-devel
|
||||||
|
# python 3.10
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20240328+cuda12.1.torch2.2.2"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.4.dev20240329+cuda12.1.torch2.2.2"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.2.2+cu121"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
onnxoptimizer \
|
||||||
|
onnxsim \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
73
docker/torch2.3.1-cuda11.8.dockerfile
Normal file
73
docker/torch2.3.1-cuda11.8.dockerfile
Normal file
@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-devel
# python 3.10

ENV LC_ALL C.UTF-8

ARG DEBIAN_FRONTEND=noninteractive

# python 3.10
ARG K2_VERSION="1.24.4.dev20240606+cuda11.8.torch2.3.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240606+cuda11.8.torch2.3.1"
ARG TORCHAUDIO_VERSION="2.3.1+cu118"

LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    curl \
    vim \
    libssl-dev \
    autoconf \
    automake \
    bzip2 \
    ca-certificates \
    ffmpeg \
    g++ \
    gfortran \
    git \
    libtool \
    make \
    patch \
    sox \
    subversion \
    unzip \
    valgrind \
    wget \
    zlib1g-dev \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies
RUN pip install --no-cache-dir \
      torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
      k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
      git+https://github.com/lhotse-speech/lhotse \
      kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
      kaldi_native_io \
      kaldialign \
      kaldifst \
      kaldilm \
      "sentencepiece>=0.1.96" \
      tensorboard \
      typeguard \
      dill \
      onnx \
      onnxruntime \
      onnxmltools \
      onnxoptimizer \
      onnxsim \
      multi_quantization \
      numpy \
      pytest \
      graphviz

RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
    pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

WORKDIR /workspace/icefall
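After building an image from one of these Dockerfiles, a quick sanity check is to run Python inside the container and confirm that the pinned wheels resolved consistently. The snippet below is our suggestion, not part of the Dockerfiles themselves; the image tag is whatever you passed to `docker build -t`:

    # Run inside the container, e.g.:
    #   docker run --gpus all --rm -it <your-image-tag> python3
    import torch
    import torchaudio
    import k2          # noqa: F401  (the import alone verifies the wheel matches torch)
    import kaldifeat   # noqa: F401

    print(torch.__version__)          # expect 2.3.1+cu118 or 2.3.1+cu121
    print(torchaudio.__version__)     # expect the TORCHAUDIO_VERSION build arg
    print(torch.cuda.is_available())  # True when started with --gpus all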
73
docker/torch2.3.1-cuda12.1.dockerfile
Normal file
@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel
# python 3.10

ENV LC_ALL C.UTF-8

ARG DEBIAN_FRONTEND=noninteractive

# python 3.10
ARG K2_VERSION="1.24.4.dev20240606+cuda12.1.torch2.3.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240606+cuda12.1.torch2.3.1"
ARG TORCHAUDIO_VERSION="2.3.1+cu121"

LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    curl \
    vim \
    libssl-dev \
    autoconf \
    automake \
    bzip2 \
    ca-certificates \
    ffmpeg \
    g++ \
    gfortran \
    git \
    libtool \
    make \
    patch \
    sox \
    subversion \
    unzip \
    valgrind \
    wget \
    zlib1g-dev \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies
RUN pip install --no-cache-dir \
      torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
      k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
      git+https://github.com/lhotse-speech/lhotse \
      kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
      kaldi_native_io \
      kaldialign \
      kaldifst \
      kaldilm \
      "sentencepiece>=0.1.96" \
      tensorboard \
      typeguard \
      dill \
      onnx \
      onnxruntime \
      onnxmltools \
      onnxoptimizer \
      onnxsim \
      multi_quantization \
      numpy \
      pytest \
      graphviz

RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
    cd /workspace/icefall && \
    pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /workspace/icefall:$PYTHONPATH

WORKDIR /workspace/icefall
@ -30,7 +30,7 @@ of language model integration.

First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) was first proposed `here <https://arxiv.org/abs/2002.11268>`_
to address the language information mismatch between the training
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:

.. math::
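
   \text{score}\left(y_u|\mathit{x},y_{1:u-1}\right) =
   \log p\left(y_u|\mathit{x},y_{1:u-1}\right) +
   \lambda_1 \log p_{\text{Target LM}}\left(y_u|y_{1:u-1}\right) -
   \lambda_2 \log p_{\text{Source LM}}\left(y_u|y_{1:u-1}\right)

.. note::

   The formula body above is our reconstruction from the surrounding text (an
   acoustic score, plus a weighted target-domain LM score, minus a weighted
   source-domain LM score); please consult the linked DR paper for the exact
   notation.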
@ -41,7 +41,7 @@ are acoustically similar, DR derives the following formula for decoding with Ba

where :math:`\lambda_1` and :math:`\lambda_2` are the weights of the LM scores for the target domain and the source domain, respectively.
Here, the source domain LM is trained on the training corpus. The only difference in the above formula compared to
shallow fusion is the subtraction of the source domain LM.

Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is

@ -58,7 +58,7 @@ during decoding for transducer model:

In LODR, an additional bi-gram LM estimated on the source domain (e.g., the training corpus) is required. Compared to DR,
the only difference lies in the choice of the source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
LODR achieves similar performance compared to DR in both intra-domain and cross-domain settings.
As a bi-gram is much faster to evaluate, LODR is usually much faster.

Now, we will show you how to use LODR in ``icefall``.
@ -9,9 +9,9 @@ to improve the word-error-rate of a transducer model.

.. note::

    This tutorial is based on the recipe
    `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
    which is a streaming transducer model trained on `LibriSpeech`_.
    However, you can easily apply shallow fusion to other recipes.
    If you encounter any problems, please open an issue in `icefall <https://github.com/k2-fsa/icefall/issues>`_.
@ -69,11 +69,11 @@ Training a language model usually takes a long time, we can download a pre-train

.. code-block:: bash

    $ # download the external LM
    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
    $ # create a symbolic link so that the checkpoint can be loaded
    $ pushd icefall-librispeech-rnn-lm/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt
    $ popd

.. note::
@ -85,7 +85,7 @@ Training a language model usually takes a long time, we can download a pre-train

To use shallow fusion for decoding, we can execute the following command:

.. code-block:: bash

    $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
    $ lm_dir=./icefall-librispeech-rnn-lm/exp
    $ lm_scale=0.29
@ -133,16 +133,16 @@ The decoding results obtained with the above command are shown below.

    $ For test-other, WER of different settings are:
    $ beam_size_4 7.08 best for test-other

The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
A few parameters can be tuned to further boost the performance of shallow fusion:

- ``--lm-scale``

  Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
  the LM score might be dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
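To make the role of ``--lm-scale`` concrete, here is a minimal sketch (ours, not
icefall's actual decoding code) of how shallow fusion interpolates the scores of
competing hypotheses during beam search:

.. code-block:: python

   def shallow_fusion_score(am_logprob: float, lm_logprob: float, lm_scale: float = 0.3) -> float:
       """Combine transducer and external LM log-probabilities for one token."""
       return am_logprob + lm_scale * lm_logprob

   # Toy example: (transducer log-prob, LM log-prob) for two candidate tokens.
   candidates = {"cat": (-1.2, -0.5), "cap": (-1.1, -3.0)}
   best = max(candidates, key=lambda t: shallow_fusion_score(*candidates[t]))
   print(best)  # "cat": the LM outweighs the slightly worse acoustic score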
- ``--beam-size``

  The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy.

Here, we also show how ``--beam-size`` affects the WER and decoding time:
@ -176,4 +176,4 @@ As we see, a larger beam size during shallow fusion improves the WER, but is als
@ -74,6 +74,10 @@ to install dependencies of `icefall`_:

    pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html

    # For users from China:
    # if huggingface is not accessible, please use
    # pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu-cn.html

    # Install the latest version of lhotse
    pip install git+https://github.com/lhotse-speech/lhotse
@ -206,6 +206,9 @@ We will install `k2`_ from pre-compiled wheels by following

.. code-block:: bash

    (test-icefall) kuangfangjun:~$ pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda.html
    # For users from China:
    # if huggingface is not accessible, please use
    # pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda-cn.html

    Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
    Looking in links: https://k2-fsa.github.io/k2/cuda.html
225
docs/source/recipes/Finetune/adapter/finetune_adapter.rst
Normal file
@ -0,0 +1,225 @@
Finetune from a pre-trained Zipformer model with adapters
=========================================================

This tutorial shows you how to fine-tune a pre-trained **Zipformer**
transducer model on a new dataset with adapters.
Adapters are compact and efficient modules that can be integrated into a pre-trained model
to improve the model's performance on a new domain. Adapters are injected
between different modules in the well-trained neural network. During training, only the parameters
in the adapters are updated. This achieves competitive performance
while requiring much less GPU memory than full fine-tuning. For more details about adapters,
please refer to the original `paper <https://arxiv.org/pdf/1902.00751.pdf#/>`_.

.. HINT::

  We assume you have read the page :ref:`install icefall` and have set up
  the environment for ``icefall``.

.. HINT::

  We recommend you use a GPU or several GPUs to run this recipe.

For illustration purposes, we fine-tune the Zipformer transducer model
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
own data for fine-tuning if you create a manifest for your new dataset.

Data preparation
----------------

Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
to prepare the fine-tuning data used in this tutorial. We only require the small subset of GigaSpeech.


Model preparation
-----------------

We are using the Zipformer model trained on the full LibriSpeech (960 hours) as the initialization. The
checkpoint of the model can be downloaded via the following command:

.. code-block:: bash

    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
    $ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt
    $ cd ../data/lang_bpe_500
    $ git lfs pull --include bpe.model
    $ cd ../../..

Before fine-tuning, let's test the model's WER on the new domain. The following command performs
decoding on the GigaSpeech test sets:

.. code-block:: bash

    ./zipformer/decode_gigaspeech.py \
        --epoch 99 \
        --avg 1 \
        --exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
        --use-averaged-model 0 \
        --max-duration 1000 \
        --decoding-method greedy_search

You should see the following numbers:

.. code-block::

    For dev, WER of different settings are:
    greedy_search 20.06 best for dev

    For test, WER of different settings are:
    greedy_search 19.27 best for test


Fine-tune with adapters
-----------------------

We insert 4 adapters with residual connections in each ``Zipformer2EncoderLayer``.
The original model parameters remain untouched during training and only the parameters of
the adapters are updated (a minimal sketch of the idea follows the command below).
The following command starts a fine-tuning experiment with adapters:

.. code-block:: bash

    $ do_finetune=1
    $ use_adapters=1
    $ adapter_dim=8

    $ ./zipformer_adapter/train.py \
        --world-size 2 \
        --num-epochs 20 \
        --start-epoch 1 \
        --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
        --use-fp16 1 \
        --base-lr 0.045 \
        --use-adapters $use_adapters --adapter-dim $adapter_dim \
        --bpe-model data/lang_bpe_500/bpe.model \
        --do-finetune $do_finetune \
        --master-port 13022 \
        --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
        --max-duration 1000
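The sketch below is ours, not the actual ``zipformer_adapter`` code; it
illustrates the bottleneck-adapter idea behind ``--adapter-dim``: a
down-projection, a non-linearity, an up-projection, and a residual connection,
with only the adapter parameters left trainable:

.. code-block:: python

   import torch
   import torch.nn as nn

   class BottleneckAdapter(nn.Module):
       """Residual adapter: down-project -> non-linearity -> up-project."""

       def __init__(self, embed_dim: int, adapter_dim: int = 8):
           super().__init__()
           self.down = nn.Linear(embed_dim, adapter_dim)
           self.up = nn.Linear(adapter_dim, embed_dim)
           self.activation = nn.ReLU()
           # Zero-init the up-projection so the adapter starts as an identity.
           nn.init.zeros_(self.up.weight)
           nn.init.zeros_(self.up.bias)

       def forward(self, x: torch.Tensor) -> torch.Tensor:
           return x + self.up(self.activation(self.down(x)))

   # Freeze everything except the adapter, which is conceptually what
   # fine-tuning with --use-adapters 1 does:
   model = nn.Sequential(nn.Linear(512, 512), BottleneckAdapter(512, 8))
   for module in model:
       for p in module.parameters():
           p.requires_grad = isinstance(module, BottleneckAdapter)
   num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
   print(f"{num_trainable} trainable parameters")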
The following arguments are related to fine-tuning:

- ``--do-finetune``

  If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
  **Note that if you want to resume your fine-tuning experiment from certain epochs, you
  need to set this to False.**

- ``--use-adapters``

  Whether adapters are used during fine-tuning.

- ``--adapter-dim``

  The bottleneck dimension of the adapter module. Typically a small number.

You should notice that in the training log, the total number of trainable parameters is shown:

.. code-block::

    2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)

The trainable parameters make up only 1.15% of the entire model's parameters, so the training will be much faster
and require less memory than full fine-tuning.


Decoding
--------

After training, let's test the WERs. To test the WERs on the GigaSpeech test sets,
you can execute the following command:

.. code-block:: bash

    $ epoch=20
    $ avg=10
    $ use_adapters=1
    $ adapter_dim=8

    $ ./zipformer/decode.py \
        --epoch $epoch \
        --avg $avg \
        --use-averaged-model 1 \
        --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
        --max-duration 600 \
        --use-adapters $use_adapters \
        --adapter-dim $adapter_dim \
        --decoding-method greedy_search

You should see the following numbers:

.. code-block::

    For dev, WER of different settings are:
    greedy_search 15.44 best for dev

    For test, WER of different settings are:
    greedy_search 15.42 best for test


The WER on the test set improves from 19.27 to 15.42, demonstrating the effectiveness of adapters.

The same model can be used to perform decoding on the LibriSpeech test sets. You can deactivate the adapters
to keep the same performance as the original model:

.. code-block:: bash

    $ epoch=20
    $ avg=1
    $ use_adapters=0
    $ adapter_dim=8

    $ ./zipformer/decode.py \
        --epoch $epoch \
        --avg $avg \
        --use-averaged-model 1 \
        --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
        --max-duration 600 \
        --use-adapters $use_adapters \
        --adapter-dim $adapter_dim \
        --decoding-method greedy_search


.. code-block::

    For dev, WER of different settings are:
    greedy_search 2.23 best for test-clean

    For test, WER of different settings are:
    greedy_search 4.96 best for test-other

The numbers are the same as reported in `icefall <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md#normal-scaled-model-number-of-model-parameters-65549011-ie-6555-m>`_, so adapter-based
fine-tuning is also very flexible: the same model can be used for decoding on the original and the target domain.


Export the model
----------------

After training, the model can be exported to ``onnx`` format easily using the following command:

.. code-block:: bash

    $ use_adapters=1
    $ adapter_dim=16

    $ ./zipformer_adapter/export-onnx.py \
        --tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
        --use-averaged-model 1 \
        --epoch 20 \
        --avg 10 \
        --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
        --use-adapters $use_adapters \
        --adapter-dim $adapter_dim \
        --num-encoder-layers "2,2,3,4,3,2" \
        --downsampling-factor "1,2,4,8,4,2" \
        --feedforward-dim "512,768,1024,1536,1024,768" \
        --num-heads "4,4,4,8,4,4" \
        --encoder-dim "192,256,384,512,384,256" \
        --query-head-dim 32 \
        --value-head-dim 12 \
        --pos-head-dim 4 \
        --pos-dim 48 \
        --encoder-unmasked-dim "192,192,256,256,256,192" \
        --cnn-module-kernel "31,31,15,15,15,31" \
        --decoder-dim 512 \
        --joiner-dim 512 \
        --causal False \
        --chunk-size "16,32,64,-1" \
        --left-context-frames "64,128,256,-1"
@ -0,0 +1,140 @@

Finetune from a supervised pre-trained Zipformer model
======================================================

This tutorial shows you how to fine-tune a supervised pre-trained **Zipformer**
transducer model on a new dataset.

.. HINT::

  We assume you have read the page :ref:`install icefall` and have set up
  the environment for ``icefall``.

.. HINT::

  We recommend you use a GPU or several GPUs to run this recipe.

For illustration purposes, we fine-tune the Zipformer transducer model
pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You could use your
own data for fine-tuning if you create a manifest for your new dataset.

Data preparation
----------------

Please follow the instructions in the `GigaSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR>`_
to prepare the fine-tuning data used in this tutorial. We only require the small subset of GigaSpeech.


Model preparation
-----------------

We are using the Zipformer model trained on the full LibriSpeech (960 hours) as the initialization. The
checkpoint of the model can be downloaded via the following command:

.. code-block:: bash

    $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
    $ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
    $ git lfs pull --include "pretrained.pt"
    $ ln -s pretrained.pt epoch-99.pt
    $ cd ../data/lang_bpe_500
    $ git lfs pull --include bpe.model
    $ cd ../../..

Before fine-tuning, let's test the model's WER on the new domain. The following command performs
decoding on the GigaSpeech test sets:

.. code-block:: bash

    ./zipformer/decode_gigaspeech.py \
        --epoch 99 \
        --avg 1 \
        --exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
        --use-averaged-model 0 \
        --max-duration 1000 \
        --decoding-method greedy_search

You should see the following numbers:

.. code-block::

    For dev, WER of different settings are:
    greedy_search 20.06 best for dev

    For test, WER of different settings are:
    greedy_search 19.27 best for test


Fine-tune
---------

Since LibriSpeech and GigaSpeech are both English datasets, we can initialize the whole
Zipformer model with the checkpoint downloaded in the previous step (otherwise we should consider
initializing the stateless decoder and joiner from scratch due to the mismatch of the output
vocabulary). The following command starts a fine-tuning experiment:

.. code-block:: bash

    $ use_mux=0
    $ do_finetune=1

    $ ./zipformer/finetune.py \
        --world-size 2 \
        --num-epochs 20 \
        --start-epoch 1 \
        --exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
        --use-fp16 1 \
        --base-lr 0.0045 \
        --bpe-model data/lang_bpe_500/bpe.model \
        --do-finetune $do_finetune \
        --use-mux $use_mux \
        --master-port 13024 \
        --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
        --max-duration 1000

The following arguments are related to fine-tuning:

- ``--base-lr``

  The learning rate used for fine-tuning. We suggest setting a **small** learning rate for fine-tuning,
  otherwise the model may forget the initialization very quickly. A reasonable value is around
  1/10 of the original lr, i.e., 0.0045.

- ``--do-finetune``

  If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
  **Note that if you want to resume your fine-tuning experiment from certain epochs, you
  need to set this to False.**

- ``--finetune-ckpt``

  The path to the pre-trained checkpoint (used for initialization).

- ``--use-mux``

  If True, mix the fine-tuning data with the original training data by using `CutSet.mux <https://lhotse.readthedocs.io/en/latest/api.html#lhotse.supervision.SupervisionSet.mux>`_ (see the sketch after this list).
  This helps maintain the model's performance on the original domain if the original training
  data is available. **If you don't have the original training data, please set it to False.**
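As an illustration of ``--use-mux``, the following sketch shows how
``CutSet.mux`` lazily interleaves two corpora. The manifest paths are
hypothetical; substitute the ones produced by your data preparation:

.. code-block:: python

   from lhotse import CutSet, load_manifest_lazy

   # Hypothetical manifest paths; substitute your own.
   giga_cuts = load_manifest_lazy("data/fbank/gigaspeech_cuts_S.jsonl.gz")
   libri_cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz")

   # `weights` controls the sampling ratio, so here the fine-tuning data is
   # drawn more often than the original training data.
   mixed = CutSet.mux(giga_cuts, libri_cuts, weights=[0.9, 0.1])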
After fine-tuning, let's test the WERs. You can do this via the following command:

.. code-block:: bash

    $ use_mux=0
    $ do_finetune=1
    $ ./zipformer/decode_gigaspeech.py \
        --epoch 20 \
        --avg 10 \
        --exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
        --use-averaged-model 1 \
        --max-duration 1000 \
        --decoding-method greedy_search

You should see numbers similar to the ones below:

.. code-block:: text

    For dev, WER of different settings are:
    greedy_search 13.47 best for dev

    For test, WER of different settings are:
    greedy_search 13.66 best for test

Compared to the original checkpoint, the fine-tuned model achieves much lower WERs
on the GigaSpeech test sets.
16
docs/source/recipes/Finetune/index.rst
Normal file
@ -0,0 +1,16 @@
Fine-tune a pre-trained model
=============================

After pre-training on publicly available datasets, the ASR model is already capable of
performing general speech recognition with relatively high accuracy. However, the accuracy
could still be low on certain domains that are quite different from the original training
set. In this case, we can fine-tune the model with a small amount of additional labelled
data to improve the performance on new domains.


.. toctree::
   :maxdepth: 2
   :caption: Table of Contents

   from_supervised/finetune_zipformer
   adapter/finetune_adapter
@ -1,11 +1,11 @@

VITS-LJSpeech
===============

This tutorial shows you how to train a VITS model
with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.

.. note::

    TTS related recipes require packages in ``requirements-tts.txt``.

.. note::

@ -13,6 +13,14 @@ with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.

    The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_


Install extra dependencies
--------------------------

.. code-block:: bash

  pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
  pip install numba espnet_tts_frontend

Data preparation
----------------

@ -56,7 +64,8 @@ Training

  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir vits/exp \
  --tokens data/tokens.txt \
  --model-type high \
  --max-duration 500

.. note::

@ -64,6 +73,11 @@ Training

   You can adjust the hyper-parameters to control the size of the VITS model and
   the training configurations. For more details, please run ``./vits/train.py --help``.

.. warning::

   If you want a model that runs faster on CPU, please use ``--model-type low``
   or ``--model-type medium``.

.. note::

   The training can take a long time (usually a couple of days).

@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir

Export models
-------------

Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
``vits-epoch-*.onnx``.

.. code-block:: bash

@ -120,4 +134,68 @@ Download pretrained models

If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:

- ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_

Usage in sherpa-onnx
--------------------

The following describes how to test the exported ONNX model in `sherpa-onnx`_.

.. hint::

   `sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
   Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.

   We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
   Please refer to `<https://k2-fsa.github.io/sherpa/onnx/>`_
   for more documentation.

Install sherpa-onnx
^^^^^^^^^^^^^^^^^^^

.. code-block:: bash

   pip install sherpa-onnx

To check that you have installed `sherpa-onnx`_ successfully, please run:

.. code-block:: bash

   which sherpa-onnx-offline-tts
   sherpa-onnx-offline-tts --help

Download lexicon files
^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash

   cd /tmp
   wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
   tar xf espeak-ng-data.tar.bz2

Run sherpa-onnx
^^^^^^^^^^^^^^^

.. code-block:: bash

   cd egs/ljspeech/TTS

   sherpa-onnx-offline-tts \
     --vits-model=vits/exp/vits-epoch-1000.onnx \
     --vits-tokens=data/tokens.txt \
     --vits-data-dir=/tmp/espeak-ng-data \
     --num-threads=1 \
     --output-filename=./high.wav \
     "Ask not what your country can do for you; ask what you can do for your country."
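If you prefer Python over the command line, a roughly equivalent sketch using
the `sherpa-onnx`_ Python API is shown below. The exact config field names may
differ slightly between sherpa-onnx versions, so please consult its
documentation:

.. code-block:: python

   import sherpa_onnx
   import soundfile as sf

   config = sherpa_onnx.OfflineTtsConfig(
       model=sherpa_onnx.OfflineTtsModelConfig(
           vits=sherpa_onnx.OfflineTtsVitsModelConfig(
               model="vits/exp/vits-epoch-1000.onnx",
               tokens="data/tokens.txt",
               data_dir="/tmp/espeak-ng-data",
           ),
           num_threads=1,
       ),
   )
   tts = sherpa_onnx.OfflineTts(config)
   audio = tts.generate("Ask not what your country can do for you.")
   sf.write("high.wav", audio.samples, audio.sample_rate)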
.. hint::

   You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
   as it is being generated.

You should get a file ``high.wav`` after running the above command.

Congratulations! You have successfully trained and exported a text-to-speech
model and run it with `sherpa-onnx`_.
@ -1,11 +1,11 @@

VITS-VCTK
===============

This tutorial shows you how to train a VITS model
with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.

.. note::

    TTS related recipes require packages in ``requirements-tts.txt``.

.. note::

@ -17,3 +17,4 @@ We may add recipes for other tasks as well in the future.

   Streaming-ASR/index
   RNN-LM/index
   TTS/index
   Finetune/index
@ -16,8 +16,8 @@ perturb_speed=true

#
# - $dl_dir/aidatatang_200zh
# You can find "corpus" and "transcript" inside it.
# You can download it at https://openslr.org/62/
# If you download the data by yourself, DON'T FORGET to extract the *.tar.gz files under corpus.

dl_dir=$PWD/download
@ -19,7 +19,9 @@ The following table lists the differences among them.

| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_200zh as extra data |
| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |

The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc

| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |

```bash
./prepare.sh

export CUDA_VISIBLE_DEVICES="0,1"
@ -1,4 +1,4 @@

Please visit
<https://k2-fsa.github.io/icefall/recipes/Non-streaming-ASR/aishell/conformer_ctc.html>
for how to run this recipe.
@ -419,7 +419,7 @@ def save_results(

    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results, char_level=True)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

@ -432,7 +432,11 @@ def save_results(

            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f,
                f"{test_set_name}-{key}",
                results_char,
                enable_log=enable_log,
                compute_CER=True,
            )
        test_set_wers[key] = wer
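The ``char_level=True`` / ``compute_CER=True`` changes above switch the error statistics to character level, which is the right granularity for Chinese. As a rough illustration of what CER measures (this is not icefall's implementation), it is an edit distance over character sequences:

    def edit_distance(ref, hyp):
        """Levenshtein distance with a rolling DP row."""
        dp = list(range(len(hyp) + 1))
        for i, r in enumerate(ref, 1):
            prev, dp[0] = dp[0], i
            for j, h in enumerate(hyp, 1):
                prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
        return dp[-1]

    ref, hyp = list("今天天气好"), list("今天天很好")
    print(f"CER = {edit_distance(ref, hyp) / len(ref):.2%}")  # one substitution over five chars: 20.00%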
@ -431,7 +431,7 @@ def save_results(

    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results, char_level=True)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

@ -444,7 +444,11 @@ def save_results(

            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f,
                f"{test_set_name}-{key}",
                results_char,
                enable_log=enable_log,
                compute_CER=True,
            )
        test_set_wers[key] = wer
@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then

fi

if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
  log "Stage 12: Train RNN LM model"
  python ../../../icefall/rnn_lm/train.py \
    --start-epoch 0 \
    --world-size 1 \
@ -390,7 +390,7 @@ def save_results(

    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results, char_level=True)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned

@ -402,7 +402,11 @@ def save_results(

            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f,
                f"{test_set_name}-{key}",
                results_char,
                enable_log=True,
                compute_CER=True,
            )
        test_set_wers[key] = wer
@ -526,7 +526,7 @@ def save_results(

    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results, char_level=True)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned

@ -538,7 +538,11 @@ def save_results(

            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f,
                f"{test_set_name}-{key}",
                results_char,
                enable_log=True,
                compute_CER=True,
            )
        test_set_wers[key] = wer
@ -444,7 +444,7 @@ def save_results(

        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))

        store_transcripts(filename=recog_path, texts=results_char, char_level=True)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned

@ -452,7 +452,11 @@ def save_results(

        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f,
                f"{test_set_name}-{key}",
                results_char,
                enable_log=True,
                compute_CER=True,
            )
        test_set_wers[key] = wer
@ -89,6 +89,7 @@ from icefall.checkpoint import (

)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,

@ -881,9 +882,7 @@ def train_one_epoch(

                if cur_grad_scale < 0.01:
                    logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    raise_grad_scale_is_too_small_error(cur_grad_scale)
            if batch_idx % params.log_interval == 0:
                cur_lr = scheduler.get_last_lr()[0]
                cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
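The imported ``raise_grad_scale_is_too_small_error`` centralizes an error that these recipes previously raised inline. A minimal sketch of what such a helper in ``icefall/err.py`` can look like; the message matches the inline one it replaces, while the extra guidance in the comment is our suggestion:

    def raise_grad_scale_is_too_small_error(cur_grad_scale: float) -> None:
        # A too-small grad scale usually means fp16 training diverged;
        # consider lowering the learning rate or disabling --use-fp16.
        raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")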
@ -85,6 +85,7 @@ from icefall.checkpoint import (

)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (

@ -878,9 +879,7 @@ def train_one_epoch(

                if cur_grad_scale < 0.01:
                    logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    raise_grad_scale_is_too_small_error(cur_grad_scale)
            if batch_idx % params.log_interval == 0:
                cur_lr = scheduler.get_last_lr()[0]
                cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
@ -581,7 +581,7 @@ def save_results(

    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results, char_level=True)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned

@ -594,7 +594,11 @@ def save_results(

        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f,
                f"{test_set_name}-{key}",
                results_char,
                enable_log=True,
                compute_CER=True,
            )
        test_set_wers[key] = wer
@ -78,6 +78,7 @@ from icefall.checkpoint import (

)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
    AttributeDict,

@ -871,9 +872,7 @@ def train_one_epoch(

                if cur_grad_scale < 0.01:
                    logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    raise_grad_scale_is_too_small_error(cur_grad_scale)

            if batch_idx % params.log_interval == 0:
                cur_lr = scheduler.get_last_lr()[0]
@ -250,7 +250,7 @@ def get_parser():

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means trigram",
    )
    parser.add_argument(
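For context on ``--context-size`` (whose default changes from 1 to 2 above): the stateless decoder conditions on only the previous ``context_size`` tokens. A toy sketch of the idea, not icefall's actual ``Decoder`` class:

    import torch
    import torch.nn as nn

    class StatelessDecoder(nn.Module):
        """Embedding + Conv1d over the last `context_size` tokens
        (1 behaves like a bigram LM, 2 like a trigram LM)."""

        def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embed_dim)
            self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

        def forward(self, y: torch.Tensor) -> torch.Tensor:
            # y: (batch, context_size) ids of the most recent tokens
            emb = self.embedding(y).transpose(1, 2)  # (batch, embed_dim, context_size)
            return self.conv(emb).squeeze(-1)        # (batch, embed_dim)

    dec = StatelessDecoder(vocab_size=500, embed_dim=512, context_size=2)
    print(dec(torch.tensor([[3, 7]])).shape)  # torch.Size([1, 512])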
@ -492,7 +492,7 @@ def save_results(

    for key, results in results_dict.items():
        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results, char_level=True)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned

@ -500,7 +500,11 @@ def save_results(

        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f,
                f"{test_set_name}-{key}",
                results,
                enable_log=True,
                compute_CER=True,
            )
        test_set_wers[key] = wer
@ -78,6 +78,7 @@ from icefall.checkpoint import (

)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

@ -882,9 +883,7 @@ def train_one_epoch(

                if cur_grad_scale < 0.01:
                    logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    raise_grad_scale_is_too_small_error(cur_grad_scale)

            if batch_idx % params.log_interval == 0:
                cur_lr = scheduler.get_last_lr()[0]
@ -78,6 +78,7 @@ from icefall.checkpoint import (

)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

@ -881,9 +882,7 @@ def train_one_epoch(

                if cur_grad_scale < 0.01:
                    logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    raise_grad_scale_is_too_small_error(cur_grad_scale)

            if batch_idx % params.log_interval == 0:
                cur_lr = scheduler.get_last_lr()[0]
@@ -278,7 +278,7 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts(filename=recog_path, texts=results, char_level=True)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
@@ -289,7 +289,13 @@ def save_results(
         for res in results:
             results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
+            wer = write_error_stats(
+                f,
+                f"{test_set_name}-{key}",
+                results_char,
+                enable_log=True,
+                compute_CER=True,
+            )
             test_set_wers[key] = wer

         logging.info("Wrote detailed error stats to {}".format(errs_filename))
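The `char_level=True` / `compute_CER=True` changes repeated across these decode scripts make the scoring character-based, which is the appropriate granularity for the Mandarin recipes touched by this merge. A small illustration of the conversion performed by the `results_char` lines above (the cut id and words here are made up):

```python
# One (cut_id, ref_words, hyp_words) entry, flattened to characters so that
# write_error_stats effectively reports CER instead of word-level WER.
res = ("cut-0001", ["你好", "世界"], ["你好", "视界"])
ref_chars = list("".join(res[1]))  # ['你', '好', '世', '界']
hyp_chars = list("".join(res[2]))  # ['你', '好', '视', '界']
results_char = [(res[0], ref_chars, hyp_chars)]  # 1 substitution in 4 chars
```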
@@ -327,7 +327,7 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts(filename=recog_path, texts=results, char_level=True)

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
@@ -338,7 +338,11 @@ def save_results(
             results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results_char, enable_log=True
+                f,
+                f"{test_set_name}-{key}",
+                results_char,
+                enable_log=True,
+                compute_CER=True,
             )
             test_set_wers[key] = wer

@@ -372,7 +372,7 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts(filename=recog_path, texts=results, char_level=True)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
@@ -384,7 +384,11 @@ def save_results(
             results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results_char, enable_log=True
+                f,
+                f"{test_set_name}-{key}",
+                results_char,
+                enable_log=True,
+                compute_CER=True,
             )
             test_set_wers[key] = wer

@@ -376,7 +376,7 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts(filename=recog_path, texts=results, char_level=True)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
@@ -388,7 +388,11 @@ def save_results(
             results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results_char, enable_log=True
+                f,
+                f"{test_set_name}-{key}",
+                results_char,
+                enable_log=True,
+                compute_CER=True,
             )
             test_set_wers[key] = wer

@@ -214,7 +214,7 @@ def get_parser():
         "--model-name",
         type=str,
         default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+        choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
         help="""The model name to use.
         """,
     )
@@ -358,7 +358,7 @@ def save_results(
             params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts(filename=recog_path, texts=results, char_level=True)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")

@@ -373,7 +373,11 @@ def save_results(
             results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+                f,
+                f"{test_set_name}-{key}",
+                results_char,
+                enable_log=enable_log,
+                compute_CER=True,
             )
             test_set_wers[key] = wer

@@ -19,7 +19,7 @@
 Usage:

 #fine-tuning with deepspeed zero stage 1
-torchrun --nproc-per-node 8 ./whisper/train.py \
+torchrun --nproc_per_node 8 ./whisper/train.py \
   --max-duration 200 \
   --exp-dir whisper/exp_large_v2 \
   --model-name large-v2 \
@@ -28,7 +28,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
   --deepspeed_config ./whisper/ds_config_zero1.json

 # fine-tuning with ddp
-torchrun --nproc-per-node 8 ./whisper/train.py \
+torchrun --nproc_per_node 8 ./whisper/train.py \
   --max-duration 200 \
   --exp-dir whisper/exp_medium \
   --manifest-dir data/fbank_whisper \
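(A hedged note on the `--nproc-per-node` → `--nproc_per_node` edits above: the underscore spelling is the one accepted by all `torchrun` versions, while the hyphenated alias only appeared in newer PyTorch releases, which is presumably why the usage examples were updated.)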
@@ -136,7 +136,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless7/exp",
+        default="whisper/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -147,7 +147,7 @@ def get_parser():
         "--model-name",
         type=str,
         default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+        choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
         help="""The model name to use.
         """,
     )
@@ -793,7 +793,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -560,7 +560,7 @@ def save_results(
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts(filename=recog_path, texts=results, char_level=True)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
@@ -570,7 +570,11 @@ def save_results(
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                compute_CER=True,
             )
             test_set_wers[key] = wer

@@ -86,6 +86,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -985,9 +986,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 save_bad_model()
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
@@ -83,6 +83,7 @@ from icefall.checkpoint import (
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -570,9 +571,7 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 save_bad_model()
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
@@ -1,6 +1,6 @@
 ## Results

 ### Aishell2 char-based training results

 #### Pruned transducer stateless 5

@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell2(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
+    num_jobs = min(8, os.cpu_count())

     dataset_parts = (
         "train",
@@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
         list(manifests.keys()),
         dataset_parts,
     )
-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
@@ -111,7 +124,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use WhisperFbank instead of Fbank. Default: False.",
+    )
     return parser.parse_args()


@@ -122,5 +140,7 @@ if __name__ == "__main__":

     args = get_args()
     compute_fbank_aishell2(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )
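Across the `compute_fbank_*.py` scripts touched in this merge, the same pattern switches between lhotse's standard `Fbank` and its Whisper-style extractor. A condensed sketch of that selection logic, mirroring the diffs (80 filters matches Whisper large-v2-style features; large-v3 models would presumably need 128):

```python
from lhotse import Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig

def make_extractor(num_mel_bins: int = 80, whisper_fbank: bool = False):
    if whisper_fbank:
        # Whisper-style log-mel filterbanks, computed on GPU
        return WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device="cuda"))
    # Default Kaldi-style fbank on CPU
    return Fbank(FbankConfig(num_mel_bins=num_mel_bins))
```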
@@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi

+whisper_mel_bins=80
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+  log "Stage 30: Compute whisper fbank for aishell2"
+  if [ ! -f data/fbank/.aishell2.whisper.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.aishell2.whisper.done
+  fi
+fi
+
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Compute fbank for musan"
   if [ ! -f data/fbank/.msuan.done ]; then
@@ -3,7 +3,7 @@

 This recipe contains various ASR models trained with Aishell4 (including the S, M and L subsets).

 The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.

 (From [Open Speech and Language Resources](https://www.openslr.org/111/))

@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell4(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests/aishell4")
     output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
+    num_jobs = min(8, os.cpu_count())

     dataset_parts = (
         "train_S",
@@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
         dataset_parts,
     )

-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
@@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=ChunkedLilcomHdf5Writer,
+                storage_type=LilcomChunkyWriter,
             )

         logging.info("About splitting cuts into smaller chunks")
@@ -121,7 +135,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use WhisperFbank instead of Fbank. Default: False.",
+    )
     return parser.parse_args()


@@ -132,5 +151,7 @@ if __name__ == "__main__":

     args = get_args()
     compute_fbank_aishell4(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail

 stage=-1
-stop_stage=100
+stop_stage=7
 perturb_speed=true


@@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Process aishell4"
+  log "Stage 2: Compute fbank for aishell4"
   if [ ! -f data/fbank/aishell4/.fbank.done ]; then
-    mkdir -p data/fbank/aishell4
+    mkdir -p data/fbank
     ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
-    touch data/fbank/aishell4/.fbank.done
+    touch data/fbank/.fbank.done
+  fi
+fi
+
+whisper_mel_bins=80
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "Stage 20: Compute whisper fbank for aishell4"
+  if [ ! -f data/fbank/aishell4/.fbank.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.fbank.done
   fi
 fi

@@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi


 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Compute fbank for aishell4"
-  if [ ! -f data/fbank/.aishell4.done ]; then
-    mkdir -p data/fbank
-    ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
-    touch data/fbank/.aishell4.done
-  fi
-fi
-
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Prepare char based lang"
+  log "Stage 5: Prepare char based lang"
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir

@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_alimeeting(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests/alimeeting")
     output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
+    num_jobs = min(8, os.cpu_count())

     dataset_parts = (
         "train",
@@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
         "test",
     )

-    prefix = "alimeeting"
+    prefix = "alimeeting-far"
     suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
         dataset_parts=dataset_parts,
@@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
         dataset_parts,
     )

-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
@@ -121,7 +135,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use the Whisper Fbank feature extractor. Default: False.",
+    )
     return parser.parse_args()


@@ -132,5 +151,7 @@ if __name__ == "__main__":

     args = get_args()
     compute_fbank_alimeeting(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail

 stage=-1
-stop_stage=100
+stop_stage=7
 perturb_speed=true

 # We assume dl_dir (download dir) contains the following
@@ -15,7 +15,7 @@ perturb_speed=true
 #
 # - $dl_dir/alimeeting
 #   This directory contains the following files downloaded from
-#   https://openslr.org/62/
+#   https://openslr.org/119/
 #
 #   - Train_Ali_far.tar.gz
 #   - Train_Ali_near.tar.gz
@@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Process alimeeting"
-  if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
-    mkdir -p data/fbank/alimeeting
+  log "Stage 2: compute fbank for alimeeting"
+  if [ ! -f data/fbank/.fbank.done ]; then
+    mkdir -p data/fbank
     ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
+    touch data/fbank/.fbank.done
+  fi
+fi
+
+whisper_mel_bins=80
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "Stage 20: compute whisper fbank for alimeeting"
+  if [ ! -f data/fbank/.fbank.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.fbank.done
   fi
 fi

@@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi


 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Compute fbank for alimeeting"
-  if [ ! -f data/fbank/.alimeeting.done ]; then
-    mkdir -p data/fbank
-    ./local/compute_fbank_alimeeting.py --perturb-speed True
-    touch data/fbank/.alimeeting.done
-  fi
-fi
-
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Prepare char based lang"
+  log "Stage 5: Prepare char based lang"
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir

|
@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting
|
|||||||
#
|
#
|
||||||
# - $dl_dir/alimeeting
|
# - $dl_dir/alimeeting
|
||||||
# This directory contains the following files downloaded from
|
# This directory contains the following files downloaded from
|
||||||
# https://openslr.org/62/
|
# https://openslr.org/119/
|
||||||
#
|
#
|
||||||
# - Train_Ali_far.tar.gz
|
# - Train_Ali_far.tar.gz
|
||||||
# - Train_Ali_near.tar.gz
|
# - Train_Ali_near.tar.gz
|
||||||
|
@ -70,6 +70,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
@ -851,9 +852,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -69,6 +69,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
@ -842,9 +843,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@ -75,6 +75,7 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
@ -1138,9 +1139,7 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
raise RuntimeError(
|
raise_grad_scale_is_too_small_error(cur_grad_scale)
|
||||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
cur_lr = scheduler.get_last_lr()[0]
|
||||||
|
@@ -75,6 +75,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1129,9 +1130,7 @@ def train_one_epoch(
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}"
-                )
+                raise_grad_scale_is_too_small_error(cur_grad_scale)

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
12
egs/audioset/AT/README.md
Normal file
@@ -0,0 +1,12 @@
# Introduction

This is an audio tagging recipe for [AudioSet](https://research.google.com/audioset/#/). It aims at predicting the sound events of an audio clip.

[./RESULTS.md](./RESULTS.md) contains the latest results.


# Zipformer

| Encoder   | Feature type      |
| --------- | ----------------- |
| Zipformer | Frame-level fbank |
95
egs/audioset/AT/RESULTS.md
Normal file
@@ -0,0 +1,95 @@
## Results

### zipformer
See <https://github.com/k2-fsa/icefall/pull/1421> for more details

[zipformer](./zipformer)

#### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12#/>

The model achieves the following mean average precision (mAP) on AudioSet:

| Model        | mAP  |
| ------------ | ---- |
| Zipformer-AT | 45.1 |

The training command is:

```bash
export CUDA_VISIBLE_DEVICES="4,5,6,7"
subset=full

python zipformer/train.py \
    --world-size 4 \
    --num-epochs 50 \
    --exp-dir zipformer/exp_at_as_${subset} \
    --start-epoch 1 \
    --use-fp16 1 \
    --num-events 527 \
    --audioset-subset $subset \
    --max-duration 1000 \
    --enable-musan True \
    --master-port 13455
```

The evaluation command is:

```bash
python zipformer/evaluate.py \
    --epoch 32 \
    --avg 8 \
    --exp-dir zipformer/exp_at_as_full \
    --max-duration 500
```


#### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-small-2024-04-23#/>

The model achieves the following mean average precision (mAP) on AudioSet:

| Model          | mAP  |
| -------------- | ---- |
| Zipformer-S-AT | 45.1 |

The training command is:

```bash
export CUDA_VISIBLE_DEVICES="4,5,6,7"
subset=full

python zipformer/train.py \
    --world-size 4 \
    --num-epochs 50 \
    --exp-dir zipformer/exp_small_at_as_${subset} \
    --start-epoch 1 \
    --use-fp16 1 \
    --num-events 527 \
    --num-encoder-layers 2,2,2,2,2,2 \
    --feedforward-dim 512,768,768,768,768,768 \
    --encoder-dim 192,256,256,256,256,256 \
    --encoder-unmasked-dim 192,192,192,192,192,192 \
    --audioset-subset $subset \
    --max-duration 1200 \
    --enable-musan True \
    --master-port 13455
```

The evaluation command is:

```bash
python zipformer/evaluate.py \
    --epoch 31 \
    --avg 4 \
    --num-encoder-layers 2,2,2,2,2,2 \
    --feedforward-dim 512,768,768,768,768,768 \
    --encoder-dim 192,256,256,256,256,256 \
    --encoder-unmasked-dim 192,192,192,192,192,192 \
    --exp-dir zipformer/exp_small_at_as_full \
    --max-duration 500
```
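The recipe treats tagging as multi-label classification over 527 AudioSet events (`--num-events 527` above), so turning the network output into event names is a sigmoid plus top-k lookup against `class_labels_indices.csv` (columns: index, mid, display_name). A hedged sketch of that post-processing (illustrative, not the recipe's exact code):

```python
import csv
import torch

def top_k_events(logits: torch.Tensor, labels_csv: str, k: int = 5):
    # logits: (num_events,) raw model output for one clip, num_events == 527
    with open(labels_csv) as f:
        rows = list(csv.reader(f))[1:]  # skip the header row
    index2name = {int(row[0]): row[2] for row in rows}
    probs = logits.sigmoid()  # independent per-event probabilities
    values, indices = probs.topk(k)
    return [(index2name[i.item()], v.item()) for v, i in zip(values, indices)]
```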
1
egs/audioset/AT/local/compute_fbank_musan.py
Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py
177
egs/audioset/AT/local/generate_audioset_manifest.py
Normal file
@@ -0,0 +1,177 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file generates the manifest and computes the fbank features for the
AudioSet dataset. The generated manifests and features are stored in data/fbank.
"""

import argparse
import csv
import glob
import logging
import os
from typing import Dict

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.audio import Recording
from lhotse.cut import MonoCut
from lhotse.supervision import SupervisionSegment

from icefall.utils import get_executor

torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_ID_mapping(csv_file):
    # get a mapping between class ID and class name
    mapping = {}
    with open(csv_file, "r") as fin:
        reader = csv.reader(fin, delimiter=",")
        for i, row in enumerate(reader):
            if i == 0:
                continue
            mapping[row[1]] = row[0]
    return mapping


def parse_csv(csv_file: str, id_mapping: Dict):
    # The content of the csv file should be something like this
    # ------------------------------------------------------
    # filename  label
    # dataset/AudioSet/balanced/xxxx.wav 0;451
    # dataset/AudioSet/balanced/xxxy.wav 375
    # ------------------------------------------------------

    def name2id(names):
        ids = [id_mapping[name] for name in names.split(",")]
        return ";".join(ids)

    mapping = {}
    with open(csv_file, "r") as fin:
        reader = csv.reader(fin, delimiter=" ")
        for i, row in enumerate(reader):
            if i <= 2:
                continue
            key = row[0].replace(",", "")
            mapping[key] = name2id(row[-1])
    return mapping


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("--dataset-dir", type=str, default="downloads/audioset")

    parser.add_argument(
        "--split",
        type=str,
        default="balanced",
        choices=["balanced", "unbalanced", "eval"],
    )

    parser.add_argument(
        "--feat-output-dir",
        type=str,
        default="data/fbank",
    )

    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    dataset_dir = args.dataset_dir
    split = args.split
    feat_output_dir = args.feat_output_dir

    num_jobs = min(15, os.cpu_count())
    num_mel_bins = 80

    if split in ["balanced", "unbalanced"]:
        csv_file = f"{dataset_dir}/{split}_train_segments.csv"
    elif split == "eval":
        csv_file = f"{dataset_dir}/eval_segments.csv"
    else:
        raise ValueError()

    class_indices_csv = f"{dataset_dir}/class_labels_indices.csv"
    id_mapping = get_ID_mapping(class_indices_csv)
    labels = parse_csv(csv_file, id_mapping)

    audio_files = glob.glob(f"{dataset_dir}/{split}/*.wav")

    new_cuts = []
    for i, audio in enumerate(audio_files):
        cut_id = audio.split("/")[-1].split("_")[0]
        recording = Recording.from_file(audio, cut_id)
        cut = MonoCut(
            id=cut_id,
            start=0.0,
            duration=recording.duration,
            channel=0,
            recording=recording,
        )
        supervision = SupervisionSegment(
            id=cut_id,
            recording_id=cut.recording.id,
            start=0.0,
            channel=0,
            duration=cut.duration,
        )
        try:
            supervision.audio_event = labels[cut_id]
        except KeyError:
            logging.info(f"No labels found for {cut_id}.")
            continue
        cut.supervisions = [supervision]
        new_cuts.append(cut)

        if i % 100 == 0 and i:
            logging.info(f"Processed {i} cuts until now.")

    cuts = CutSet.from_cuts(new_cuts)

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    logging.info(f"Computing fbank features for {split}")
    with get_executor() as ex:
        cuts = cuts.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{feat_output_dir}/{split}_feats",
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )

    manifest_output_dir = feat_output_dir + "/" + f"cuts_audioset_{split}.jsonl.gz"

    logging.info(f"Storing the manifest to {manifest_output_dir}")
    cuts.to_jsonl(manifest_output_dir)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
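The `audio_event` supervision field written by this script is a `;`-joined string of class indices (e.g. `"0;451"`). The training code presumably expands it into a multi-hot target over the 527 classes; a minimal sketch of that conversion (the function name is illustrative, not the icefall API):

```python
import torch

def audio_event_to_multi_hot(audio_event: str, num_events: int = 527) -> torch.Tensor:
    # "0;451" -> a (num_events,) vector with 1.0 at positions 0 and 451
    target = torch.zeros(num_events)
    for idx in audio_event.split(";"):
        target[int(idx)] = 1.0
    return target
```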
104
egs/audioset/AT/prepare.sh
Executable file
@@ -0,0 +1,104 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

# run stage -1 to stage 4 by default
stage=-1
stop_stage=4

dl_dir=$PWD/download

# we assume that you have downloaded the AudioSet corpus and placed
# it under $dl_dir/audioset; the folder structure should look like
# this:
# - $dl_dir/audioset
#   - balanced
#   - eval
#   - unbalanced
# If you haven't downloaded AudioSet, please refer to
# https://github.com/RicherMans/SAT/blob/main/datasets/audioset/1_download_audioset.sh.

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "Running prepare.sh"

log "dl_dir: $dl_dir"

if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
  log "Stage -1: Download the necessary csv files"
  if [ ! -e $dl_dir/audioset/.csv.done ]; then
    wget --continue "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv" -O "${dl_dir}/audioset/class_labels_indices.csv"
    wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv -O "${dl_dir}/audioset/balanced_train_segments.csv"
    wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/eval_segments.csv -O "${dl_dir}/audioset/eval_segments.csv"
    touch $dl_dir/audioset/.csv.done
  fi
fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
  fbank_dir=data/fbank
  if [ ! -e $fbank_dir/.balanced.done ]; then
    python local/generate_audioset_manifest.py \
      --dataset-dir $dl_dir/audioset \
      --split balanced \
      --feat-output-dir $fbank_dir
    touch $fbank_dir/.balanced.done
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Construct the audioset manifest and compute the fbank features for unbalanced set"
  fbank_dir=data/fbank
  if [ ! -e $fbank_dir/.unbalanced.done ]; then
    python local/generate_audioset_manifest.py \
      --dataset-dir $dl_dir/audioset \
      --split unbalanced \
      --feat-output-dir $fbank_dir
    touch $fbank_dir/.unbalanced.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Construct the audioset manifest and compute the fbank features for eval set"
  fbank_dir=data/fbank
  if [ ! -e $fbank_dir/.eval.done ]; then
    python local/generate_audioset_manifest.py \
      --dataset-dir $dl_dir/audioset \
      --split eval \
      --feat-output-dir $fbank_dir
    touch $fbank_dir/.eval.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Prepare musan manifest"
  # We assume that you have downloaded the musan corpus
  # to $dl_dir/musan
  mkdir -p data/manifests
  if [ ! -e data/manifests/.musan.done ]; then
    lhotse prepare musan $dl_dir/musan data/manifests
    touch data/manifests/.musan.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Compute fbank for musan"
  mkdir -p data/fbank
  if [ ! -e data/fbank/.musan.done ]; then
    ./local/compute_fbank_musan.py
    touch data/fbank/.musan.done
  fi
fi
1
egs/audioset/AT/shared
Symbolic link
@@ -0,0 +1 @@
../../../icefall/shared
@@ -1,7 +1,6 @@
-# Copyright 2021 Piotr Żelasko
-# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
 #
-# See ../../../../LICENSE for clarification regarding multiple authors
+# See ../LICENSE for clarification regarding multiple authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import argparse
 import inspect
 import logging
@@ -26,12 +24,12 @@ from typing import Any, Dict, Optional
 import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+    AudioTaggingDataset,
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
-    K2SpeechRecognitionDataset,
     PrecomputedFeatures,
-    SingleCutSampler,
+    SimpleCutSampler,
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
@@ -52,14 +50,12 @@ class _SeedWorkers:
         fix_random_seed(self.seed + worker_id)


-class CommonVoiceAsrDataModule:
+class AudioSetATDatamodule:
     """
-    DataModule for k2 ASR experiments.
-    It assumes there is always one train and valid dataloader,
-    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
-    and test-other).
+    DataModule for k2 audio tagging (AT) experiments.

-    It contains all the common data pipeline modules used in ASR
+    It contains all the common data pipeline modules used in AT
     experiments, e.g.:
     - dynamic batch size,
     - bucketing samplers,
@@ -67,7 +63,7 @@ class CommonVoiceAsrDataModule:
     - augmentation,
     - on-the-fly feature extraction

-    This class should be derived for specific corpora used in ASR tasks.
+    This class should be derived for specific corpora used in AT tasks.
     """

     def __init__(self, args: argparse.Namespace):
@@ -76,7 +72,7 @@ class CommonVoiceAsrDataModule:
     @classmethod
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
-            title="ASR data related options",
+            title="AT data related options",
             description="These options are used for the preparation of "
             "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
             "effective batch sizes, sampling strategies, applied data "
@ -84,22 +80,17 @@ class CommonVoiceAsrDataModule:
|
|||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--language",
|
"--audioset-subset",
|
||||||
type=str,
|
type=str,
|
||||||
default="fr",
|
default="balanced",
|
||||||
help="""Language of Common Voice""",
|
choices=["balanced", "full"],
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--cv-manifest-dir",
|
|
||||||
type=Path,
|
|
||||||
default=Path("data/fr/fbank"),
|
|
||||||
help="Path to directory with CommonVoice train/dev/test cuts.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--manifest-dir",
|
"--manifest-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=Path("data/fbank"),
|
default=Path("data/fbank"),
|
||||||
help="Path to directory with train/valid/test cuts.",
|
help="Path to directory with audioset train/test cuts.",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--max-duration",
|
"--max-duration",
|
||||||
@ -218,7 +209,7 @@ class CommonVoiceAsrDataModule:
|
|||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
) -> DataLoader:
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
cuts_train:
|
cuts_train:
|
||||||
@ -232,7 +223,7 @@ class CommonVoiceAsrDataModule:
|
|||||||
logging.info("About to get Musan cuts")
|
logging.info("About to get Musan cuts")
|
||||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||||
transforms.append(
|
transforms.append(
|
||||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Disable MUSAN")
|
logging.info("Disable MUSAN")
|
||||||
@ -278,7 +269,7 @@ class CommonVoiceAsrDataModule:
|
|||||||
logging.info("Disable SpecAugment")
|
logging.info("Disable SpecAugment")
|
||||||
|
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
train = K2SpeechRecognitionDataset(
|
train = AudioTaggingDataset(
|
||||||
input_strategy=eval(self.args.input_strategy)(),
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
@ -296,7 +287,7 @@ class CommonVoiceAsrDataModule:
|
|||||||
# to be strict (e.g. could be randomized)
|
# to be strict (e.g. could be randomized)
|
||||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||||
# Drop feats to be on the safe side.
|
# Drop feats to be on the safe side.
|
||||||
train = K2SpeechRecognitionDataset(
|
train = AudioTaggingDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
@ -310,16 +301,15 @@ class CommonVoiceAsrDataModule:
|
|||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
buffer_size=self.args.num_buckets * 2000,
|
|
||||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
|
||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
@ -354,13 +344,13 @@ class CommonVoiceAsrDataModule:
|
|||||||
|
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = AudioTaggingDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = AudioTaggingDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
@ -382,10 +372,12 @@ class CommonVoiceAsrDataModule:
|
|||||||
|
|
||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
test = K2SpeechRecognitionDataset(
|
test = AudioTaggingDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=(
|
||||||
if self.args.on_the_fly_feats
|
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
else eval(self.args.input_strategy)(),
|
if self.args.on_the_fly_feats
|
||||||
|
else eval(self.args.input_strategy)()
|
||||||
|
),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
@ -403,22 +395,28 @@ class CommonVoiceAsrDataModule:
|
|||||||
return test_dl
|
return test_dl
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def audioset_train_cuts(self) -> CutSet:
|
||||||
logging.info("About to get train cuts")
|
logging.info("About to get the audioset training cuts.")
|
||||||
return load_manifest_lazy(
|
balanced_cuts = load_manifest_lazy(
|
||||||
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz"
|
self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
if self.args.audioset_subset == "full":
|
||||||
|
unbalanced_cuts = load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
|
||||||
|
)
|
||||||
|
cuts = CutSet.mux(
|
||||||
|
balanced_cuts,
|
||||||
|
unbalanced_cuts,
|
||||||
|
weights=[20000, 2000000],
|
||||||
|
stop_early=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cuts = balanced_cuts
|
||||||
|
return cuts
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def dev_cuts(self) -> CutSet:
|
def audioset_eval_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev cuts")
|
logging.info("About to get audioset eval cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz"
|
self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get test cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz"
|
|
||||||
)
|
)
|
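The `audioset_train_cuts` change above merges the balanced and unbalanced subsets with lhotse's `CutSet.mux`, which samples lazily from several cut sets in proportion to the given weights (here roughly the subset sizes, ~20k vs ~2M cuts) and, with `stop_early=True`, stops once the first source is exhausted. A minimal sketch of the same mechanism, reusing the manifest names from the diff:

    from lhotse import CutSet, load_manifest_lazy

    # Two lazily-opened manifests; the weights bias sampling towards the
    # larger set so one pass interleaves both subsets in realistic proportion.
    balanced = load_manifest_lazy("data/fbank/cuts_audioset_balanced.jsonl.gz")
    unbalanced = load_manifest_lazy("data/fbank/cuts_audioset_unbalanced.jsonl.gz")

    mixed = CutSet.mux(
        balanced,
        unbalanced,
        weights=[20000, 2000000],  # approximate subset sizes, not exact counts
        stop_early=True,  # stop as soon as one source runs out
    )

    for cut in mixed:
        print(cut.id)  # cuts drawn from both subsets, weighted at random
        break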
1
egs/audioset/AT/zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/transducer_stateless/encoder_interface.py
327
egs/audioset/AT/zipformer/evaluate.py
Normal file
@ -0,0 +1,327 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:

export CUDA_VISIBLE_DEVICES="0"

./zipformer/evaluate.py \
    --epoch 50 \
    --avg 10 \
    --exp-dir zipformer/exp \
    --max-duration 1000

"""

import argparse
import logging
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn as nn
from at_datamodule import AudioSetATDatamodule

try:
    from sklearn.metrics import average_precision_score
except ImportError:
    raise ImportError("Please run\n" "pip3 install -U scikit-learn")
from train import add_model_arguments, get_model, get_params, str2multihot

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    add_model_arguments(parser)

    return parser


def inference_one_batch(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
):
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3, feature.shape

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    audio_event = supervisions["audio_event"]

    label, _ = str2multihot(audio_event)
    label = label.detach().cpu()

    feature_lens = supervisions["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)

    audio_logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
    # convert to probabilities between 0-1
    audio_logits = audio_logits.sigmoid().detach().cpu()

    return audio_logits, label


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    all_logits = []
    all_labels = []

    for batch_idx, batch in enumerate(dl):
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
        num_cuts += len(cut_ids)

        audio_logits, labels = inference_one_batch(
            params=params,
            model=model,
            batch=batch,
        )

        all_logits.append(audio_logits)
        all_labels.append(labels)

        if batch_idx % 20 == 1:
            logging.info(f"Processed {num_cuts} cuts already.")
    logging.info("Finish collecting audio logits")

    return all_logits, all_labels


@torch.no_grad()
def main():
    parser = get_parser()
    AudioSetATDatamodule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.res_dir = params.exp_dir / "inference_audio_tagging"

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Evaluation started")

    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info("About to create model")

    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(
                average_checkpoints(filenames, device=device), strict=False
            )
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(
                average_checkpoints(filenames, device=device), strict=False
            )
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                ),
                strict=False,
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                ),
                strict=False,
            )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    args.return_cuts = True
    audioset = AudioSetATDatamodule(args)

    audioset_cuts = audioset.audioset_eval_cuts()

    audioset_dl = audioset.valid_dataloaders(audioset_cuts)

    test_sets = ["audioset_eval"]

    logits, labels = decode_dataset(
        dl=audioset_dl,
        params=params,
        model=model,
    )

    logits = torch.cat(logits, dim=0).squeeze(dim=1).detach().numpy()
    labels = torch.cat(labels, dim=0).long().detach().numpy()

    # compute the metric
    mAP = average_precision_score(
        y_true=labels,
        y_score=logits,
    )

    logging.info(f"mAP for audioset eval is: {mAP}")

    logging.info("Done")


if __name__ == "__main__":
    main()
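evaluate.py stacks the per-cut sigmoid outputs and multi-hot labels and hands them to scikit-learn. For reference, a toy sketch of what `average_precision_score` computes on a multi-label batch (by default it macro-averages average precision over the classes, which is the "mAP" logged above; the arrays here are made up):

    import numpy as np
    from sklearn.metrics import average_precision_score

    # 4 clips, 3 classes; each row is a multi-hot ground-truth label vector.
    y_true = np.array(
        [
            [1, 0, 0],
            [0, 1, 1],
            [1, 1, 0],
            [0, 0, 1],
        ]
    )
    # Per-class probabilities, e.g. sigmoid(logits) from the tagger.
    y_score = np.array(
        [
            [0.9, 0.2, 0.1],
            [0.1, 0.8, 0.7],
            [0.7, 0.6, 0.2],
            [0.2, 0.1, 0.9],
        ]
    )

    # Macro-averaged AP over the three classes.
    print(average_precision_score(y_true=y_true, y_score=y_score))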
411
egs/audioset/AT/zipformer/export-onnx.py
Executable file
@ -0,0 +1,411 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
This script exports an audio tagging model from PyTorch to ONNX.

Usage of this script:

repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
repo=$(basename $repo_url)
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo/exp
git lfs pull --include pretrained.pt
ln -s pretrained.pt epoch-99.pt
popd

python3 zipformer/export-onnx.py \
  --exp-dir $repo/exp \
  --epoch 99 \
  --avg 1 \
  --use-averaged-model 0

pushd $repo/exp
mv model-epoch-99-avg-1.onnx model.onnx
mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
popd

See ./onnx_pretrained.py for how to use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict

import onnx
import onnxoptimizer
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from onnxsim import simplify
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    add_model_arguments(parser)

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


class OnnxAudioTagger(nn.Module):
    """A wrapper for Zipformer audio tagger"""

    def __init__(
        self, encoder: Zipformer2, encoder_embed: nn.Module, classifier: nn.Linear
    ):
        """
        Args:
          encoder:
            A Zipformer encoder.
          encoder_embed:
            The convolutional subsampling module in front of the encoder.
          classifier:
            The linear layer mapping encoder output to per-event logits.
        """
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.classifier = classifier

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Please see the help information of Zipformer.forward

        Args:
          x:
            A 3-D tensor of shape (N, T, C)
          x_lens:
            A 1-D tensor of shape (N,). Its dtype is torch.int64
        Returns:
          Return a tensor containing:
            - probs, a 2-D tensor of shape (N, num_classes)
        """
        x, x_lens = self.encoder_embed(x, x_lens)
        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)
        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)  # (N,T,C)

        logits = self.classifier(encoder_out)  # (N, T, num_classes)
        # Note that this is slightly different from model.py for better
        # support of onnx
        logits = logits.mean(dim=1)
        probs = logits.sigmoid()
        return probs


def export_audio_tagging_model_onnx(
    model: OnnxAudioTagger,
    filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given audio tagging model to ONNX format.
    The exported model has two inputs:

      - x, a tensor of shape (N, T, C); dtype is torch.float32
      - x_lens, a tensor of shape (N,); dtype is torch.int64

    and it has one output:

      - probs, a tensor of shape (N, num_classes); dtype is torch.float32

    Args:
      model:
        The input audio tagging model
      filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    x = torch.zeros(1, 200, 80, dtype=torch.float32)
    x_lens = torch.tensor([200], dtype=torch.int64)

    model = torch.jit.trace(model, (x, x_lens))

    torch.onnx.export(
        model,
        (x, x_lens),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["probs"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "probs": {0: "N"},
        },
    )

    meta_data = {
        "model_type": "zipformer2",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "zipformer2 audio tagger",
        "url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=filename, meta_data=meta_data)


def optimize_model(filename):
    # see
    # https://github.com/microsoft/onnxruntime/issues/1899#issuecomment-534806537
    # and
    # https://github.com/onnx/onnx/issues/582#issuecomment-937788108
    # and
    # https://github.com/onnx/optimizer/issues/110
    # and
    # https://qiita.com/Yossy_Hal/items/34f3b2aef2199baf7f5f
    passes = ["eliminate_unused_initializer"]
    onnx_model = onnx.load(filename)
    onnx_model = onnxoptimizer.optimize(onnx_model, passes)

    model_simp, check = simplify(onnx_model)
    if check:
        logging.info("Simplified the model!")
        onnx_model = model_simp
    else:
        logging.info("Failed to simplify the model!")

    onnx.save(onnx_model, filename)


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    model.to(device)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to("cpu")
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)

    model = OnnxAudioTagger(
        encoder=model.encoder,
        encoder_embed=model.encoder_embed,
        classifier=model.classifier,
    )

    model_num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"total parameters: {model_num_param}")

    if params.iter > 0:
        suffix = f"iter-{params.iter}"
    else:
        suffix = f"epoch-{params.epoch}"

    suffix += f"-avg-{params.avg}"

    opset_version = 13

    logging.info("Exporting audio tagging model")
    model_filename = params.exp_dir / f"model-{suffix}.onnx"
    export_audio_tagging_model_onnx(
        model,
        model_filename,
        opset_version=opset_version,
    )
    optimize_model(model_filename)
    logging.info(f"Exported audio tagging model to {model_filename}")

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    model_filename_int8 = params.exp_dir / f"model-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=model_filename,
        model_output=model_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )
    optimize_model(model_filename_int8)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
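The exported graph takes `x` (a batch of 80-dim fbank frames) and `x_lens`, and returns per-class probabilities directly, since the sigmoid is baked into the traced forward. A minimal onnxruntime sketch under those assumptions; feature extraction is elided here and the input is random, so the output is only meaningful for checking shapes:

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

    # One fake utterance: 200 frames of 80-dim fbank features.
    x = np.random.randn(1, 200, 80).astype(np.float32)
    x_lens = np.array([200], dtype=np.int64)

    # Input/output names follow the export above; probs is (N, num_classes).
    (probs,) = session.run(None, {"x": x, "x_lens": x_lens})
    print(probs.shape, probs.max())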
340
egs/audioset/AT/zipformer/export.py
Executable file
@ -0,0 +1,340 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao,
#                                                 Wei Kang,
#                                                 Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:

Note: This is an example for the AudioSet dataset. If you are using a
different dataset, you should change the argument values according to
your dataset.

(1) Export to torchscript model using torch.jit.script()

./zipformer/export.py \
  --exp-dir ./zipformer/exp \
  --epoch 30 \
  --avg 9 \
  --jit 1

It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.

Check ./jit_pretrained.py for its usage.

Check https://github.com/k2-fsa/sherpa
and https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.

(2) Export `model.state_dict()`

./zipformer/export.py \
  --exp-dir ./zipformer/exp \
  --epoch 30 \
  --avg 9

It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.

To use the generated file with `zipformer/evaluate.py`,
you can do:

    cd /path/to/exp_dir
    ln -s pretrained.pt epoch-9999.pt

    cd /path/to/egs/audioset/AT
    ./zipformer/evaluate.py \
        --exp-dir ./zipformer/exp \
        --use-averaged-model False \
        --epoch 9999 \
        --avg 1 \
        --max-duration 600

Check ./pretrained.py for its usage.

"""

import argparse
import logging
from pathlib import Path
from typing import Tuple

import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
        It will generate a file named jit_script.pt.
        Check ./jit_pretrained.py for how to use it.
        """,
    )

    add_model_arguments(parser)

    return parser


class EncoderModel(nn.Module):
    """A wrapper for encoder and encoder_embed"""

    def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed

    def forward(
        self, features: Tensor, feature_lengths: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          features: (N, T, C)
          feature_lengths: (N,)
        """
        x, x_lens = self.encoder_embed(features, feature_lengths)

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return encoder_out, encoder_out_lens


class Classifier(nn.Module):
    """A wrapper for audio tagging classifier"""

    def __init__(self, classifier: nn.Module) -> None:
        super().__init__()
        self.classifier = classifier

    def forward(self, encoder_out: Tensor, encoder_out_lens: Tensor):
        """
        Args:
          encoder_out:
            A 3-D tensor of shape (N, T, C).
          encoder_out_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
        """
        logits = self.classifier(encoder_out)  # (N, T, num_classes)
        padding_mask = make_pad_mask(encoder_out_lens)
        logits[padding_mask] = 0
        logits = logits.sum(dim=1)  # mask the padding frames
        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
            logits
        )  # normalize the logits

        return logits


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")

    logging.info(f"device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.eval()

    if params.jit is True:
        convert_scaled_to_non_scaled(model, inplace=True)
        # We won't use the forward() method of the model in C++, so just ignore
        # it here.
        # Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptable.
        model.__class__.forward = torch.jit.ignore(model.__class__.forward)

        model.encoder = EncoderModel(model.encoder, model.encoder_embed)
        model.classifier = Classifier(model.classifier)
        filename = "jit_script.pt"

        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        model.save(str(params.exp_dir / filename))
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torchscript. Export model.state_dict()")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
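The two export paths produce artifacts that are loaded differently: `jit_script.pt` is a self-contained TorchScript module, while `pretrained.pt` is a plain checkpoint dict. A short sketch of both load paths, following the docstring above (paths are placeholders):

    import torch

    # (1) TorchScript export: self-contained, no Python model definition needed.
    jit_model = torch.jit.load("zipformer/exp/jit_script.pt")
    jit_model.eval()

    # (2) state_dict export: the checkpoint is a dict with a "model" key,
    #     meant to be loaded into a freshly constructed AudioTaggingModel
    #     (e.g. via icefall.checkpoint.load_checkpoint).
    state = torch.load("zipformer/exp/pretrained.pt", map_location="cpu")
    print(state.keys())  # dict_keys(['model'])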
191
egs/audioset/AT/zipformer/jit_pretrained.py
Executable file
@ -0,0 +1,191 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#               2024 Xiaoyu Yang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`,
and uses them to decode waves.
You can use the following command to get the exported models:

./zipformer/export.py \
  --exp-dir ./zipformer/exp \
  --epoch 30 \
  --avg 9 \
  --jit 1

Usage of this script:

repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
repo=$(basename $repo_url)
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo/exp
git lfs pull --include jit_script.pt
popd

python3 zipformer/jit_pretrained.py \
  --nn-model-filename $repo/exp/jit_script.pt \
  --label-dict $repo/data/class_labels_indices.csv \
  $repo/test_wavs/1.wav \
  $repo/test_wavs/2.wav \
  $repo/test_wavs/3.wav \
  $repo/test_wavs/4.wav
"""

import argparse
import csv
import logging
import math
from typing import List

import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model jit_script.pt",
    )

    parser.add_argument(
        "--label-dict",
        type=str,
        help="""class_labels_indices.csv.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to be classified. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = torch.jit.load(args.nn_model_filename)

    model.eval()

    model.to(device)

    # get the label dictionary
    label_dict = {}
    with open(args.label_dict, "r") as f:
        reader = csv.reader(f, delimiter=",")
        for i, row in enumerate(reader):
            if i == 0:
                continue
            label_dict[int(row[0])] = row[2]

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80
    opts.mel_opts.high_freq = -400

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = model.encoder(
        features=features,
        feature_lengths=feature_lengths,
    )

    logits = model.classifier(encoder_out, encoder_out_lens)

    for filename, logit in zip(args.sound_files, logits):
        topk_prob, topk_index = logit.sigmoid().topk(5)
        topk_labels = [label_dict[index.item()] for index in topk_index]
        logging.info(
            f"{filename}: Top 5 predicted labels are {topk_labels} with "
            f"probability of {topk_prob.tolist()}"
        )

    logging.info("Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
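jit_pretrained.py asserts 16 kHz input. If your test wavs use a different sample rate, one option is to resample them first, e.g. with torchaudio (a small sketch; the file names are placeholders):

    import torchaudio
    import torchaudio.functional as F

    wave, sr = torchaudio.load("my_audio.wav")
    if sr != 16000:
        # Resample to the 16 kHz rate expected by the fbank extractor above.
        wave = F.resample(wave, orig_freq=sr, new_freq=16000)
    torchaudio.save("my_audio_16k.wav", wave, 16000)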
153
egs/audioset/AT/zipformer/model.py
Normal file
153
egs/audioset/AT/zipformer/model.py
Normal file
@ -0,0 +1,153 @@
# Copyright 2021-2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
import torch.nn as nn
from encoder_interface import EncoderInterface

from icefall.utils import make_pad_mask


class AudioTaggingModel(nn.Module):
    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: EncoderInterface,
        encoder_dim: int = 384,
        num_events: int = 527,
    ):
        """An audio tagging model.

        Args:
          encoder_embed:
            It is a Convolutional 2D subsampling module. It converts
            an input of shape (N, T, idim) to an output of shape
            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
          encoder:
            It is the transcription network in the paper. It accepts
            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          encoder_dim:
            Dimension of the encoder.
          num_events:
            The number of event classes.
        """
        super().__init__()

        assert isinstance(encoder, EncoderInterface), type(encoder)

        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.encoder_dim = encoder_dim

        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(encoder_dim, num_events),
        )

        # for multi-label classification
        self.criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")

    def forward_encoder(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
        x, x_lens = self.encoder_embed(x, x_lens)
        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)

        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        return encoder_out, encoder_out_lens

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          target:
            The ground-truth labels of the audio events. It may be a multi-hot
            tensor of shape (N, num_events).

        Returns:
          Return the binary cross-entropy loss.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

        # Forward the audio tagging module
        logits = self.forward_audio_tagging(
            encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
        )  # (N, num_events)

        loss = self.criterion(logits, target)

        return loss

    def forward_audio_tagging(self, encoder_out, encoder_out_lens):
        """
        Args:
          encoder_out:
            A 3-D tensor of shape (N, T, C).
          encoder_out_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in
            `encoder_out` before padding.

        Returns:
          A 2-D tensor of shape (N, num_events).
        """
        logits = self.classifier(encoder_out)  # (N, T, num_events)
        padding_mask = make_pad_mask(encoder_out_lens)
        logits[padding_mask] = 0  # mask the padding frames
        logits = logits.sum(dim=1)  # sum over the time axis
        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
            logits
        )  # average over the valid frames

        return logits
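The masked pooling in forward_audio_tagging averages the per-frame logits over only the valid (non-padded) frames. The following standalone sketch reproduces that computation on dummy tensors; make_pad_mask is re-implemented inline here under the assumption, consistent with its use above, that it returns True at padded positions.

# Standalone sketch of the masked mean pooling in forward_audio_tagging.
import torch


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    max_len = int(lengths.max())
    # (N, T): True where the frame index is >= the utterance length
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)


N, T, num_events = 2, 5, 4
logits = torch.randn(N, T, num_events)
lengths = torch.tensor([5, 3])  # second utterance has 2 padded frames

padding_mask = make_pad_mask(lengths)    # (N, T)
logits[padding_mask] = 0                 # zero out padded frames
pooled = logits.sum(dim=1)               # (N, num_events)
pooled = pooled / lengths.unsqueeze(-1)  # average over valid frames only

print(pooled.shape)  # torch.Size([2, 4])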
227
egs/audioset/AT/zipformer/onnx_pretrained.py
Executable file
@ -0,0 +1,227 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#           2022 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.

Usage of this script:

  repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09
  repo=$(basename $repo_url)
  git clone $repo_url
  pushd $repo
  git lfs pull --include "*.onnx"
  popd

  for m in model.onnx model.int8.onnx; do
    python3 zipformer/onnx_pretrained.py \
      --model-filename $repo/$m \
      --label-dict $repo/class_labels_indices.csv \
      $repo/test_wavs/1.wav \
      $repo/test_wavs/2.wav \
      $repo/test_wavs/3.wav \
      $repo/test_wavs/4.wav
  done
"""

import argparse
import csv
import logging
import math
from typing import List

import kaldifeat
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        required=True,
        help="Path to the onnx model.",
    )

    parser.add_argument(
        "--label-dict",
        type=str,
        help="""class_labels_indices.csv.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    return parser


class OnnxModel:
    def __init__(
        self,
        nn_model: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.session_opts = session_opts

        self.init_model(nn_model)

    def init_model(self, nn_model: str):
        self.model = ort.InferenceSession(
            nn_model,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        meta = self.model.get_modelmeta().custom_metadata_map
        print(meta)

    def __call__(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). Its dtype is torch.int64.
        Returns:
          Return a tensor:
            - probs, of shape (N, num_classes)
        """
        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(out[0])


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0])
    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))
    model = OnnxModel(
        nn_model=args.model_filename,
    )

    # get the label dictionary
    label_dict = {}
    with open(args.label_dict, "r") as f:
        reader = csv.reader(f, delimiter=",")
        for i, row in enumerate(reader):
            if i == 0:
                continue
            label_dict[int(row[0])] = row[2]

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = "cpu"
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = args.sample_rate
    opts.mel_opts.num_bins = 80
    opts.mel_opts.high_freq = -400

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
        expected_sample_rate=args.sample_rate,
    )

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
    probs = model(features, feature_lengths)

    for filename, prob in zip(args.sound_files, probs):
        topk_prob, topk_index = prob.topk(5)
        topk_labels = [label_dict[index.item()] for index in topk_index]
        logging.info(
            f"{filename}: Top 5 predicted labels are {topk_labels} with "
            f"probability of {topk_prob.tolist()}"
        )

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
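A quick way to sanity-check an exported audio-tagging ONNX model is to feed random fbank-shaped input through onnxruntime and inspect the output shape. A sketch, assuming a local file named model.onnx whose two inputs are the features and their lengths, in that order, as the script above assumes:

# Sketch: shape-check an exported audio-tagging ONNX model.
# "model.onnx" is an assumed local path; input/output names are read
# from the graph rather than hard-coded.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

x = np.random.randn(1, 200, 80).astype(np.float32)  # (N, T, C) dummy fbank
x_lens = np.array([200], dtype=np.int64)             # (N,)

inputs = {
    session.get_inputs()[0].name: x,
    session.get_inputs()[1].name: x_lens,
}
(probs,) = session.run([session.get_outputs()[0].name], inputs)
print(probs.shape)  # expected: (1, num_classes), e.g. (1, 527) for AudioSet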
1
egs/audioset/AT/zipformer/optim.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py
202
egs/audioset/AT/zipformer/pretrained.py
Executable file
@ -0,0 +1,202 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:

Note: This is an example for the AudioSet dataset. If you are using a
different dataset, you should change the argument values accordingly.

Usage of this script:

  repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
  repo=$(basename $repo_url)
  GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  pushd $repo/exp
  git lfs pull --include pretrained.pt
  popd

  python3 zipformer/pretrained.py \
    --checkpoint $repo/exp/pretrained.pt \
    --label-dict $repo/data/class_labels_indices.csv \
    $repo/test_wavs/1.wav \
    $repo/test_wavs/2.wav \
    $repo/test_wavs/3.wav \
    $repo/test_wavs/4.wav
"""


import argparse
import csv
import logging
import math
from typing import List

import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--label-dict",
        type=str,
        help="""class_labels_indices.csv.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    add_model_arguments(parser)

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()

    params.update(vars(args))

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()

    # get the label dictionary
    label_dict = {}
    with open(params.label_dict, "r") as f:
        reader = csv.reader(f, delimiter=",")
        for i, row in enumerate(reader):
            if i == 0:
                continue
            label_dict[int(row[0])] = row[2]

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim
    opts.mel_opts.high_freq = -400

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    # model forward and predict the audio events
    encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
    logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)

    for filename, logit in zip(args.sound_files, logits):
        topk_prob, topk_index = logit.sigmoid().topk(5)
        topk_labels = [label_dict[index.item()] for index in topk_index]
        logging.info(
            f"{filename}: Top 5 predicted labels are {topk_labels} with "
            f"probability of {topk_prob.tolist()}"
        )

    logging.info("Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
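Both pretrained.py and onnx_pretrained.py build label_dict from class_labels_indices.csv, which follows the standard AudioSet index,mid,display_name layout, and then decode the top-5 classes with sigmoid plus topk. A toy sketch of those two steps; the CSV rows and logits below are illustrative only (real files have 527 classes):

# Sketch: label-dict parsing and top-k decoding, on toy data.
import csv
import io

import torch

csv_text = 'index,mid,display_name\n0,/m/09x0r,"Speech"\n1,/m/04rlf,"Music"\n'
label_dict = {}
reader = csv.reader(io.StringIO(csv_text), delimiter=",")
for i, row in enumerate(reader):
    if i == 0:  # skip the header row
        continue
    label_dict[int(row[0])] = row[2]

logits = torch.tensor([2.0, -1.0])  # dummy per-class logits
topk_prob, topk_index = logits.sigmoid().topk(2)
topk_labels = [label_dict[i.item()] for i in topk_index]
print(topk_labels, topk_prob.tolist())  # ['Speech', 'Music'] and their probs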
1
egs/audioset/AT/zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py
1
egs/audioset/AT/zipformer/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py
1
egs/audioset/AT/zipformer/subsampling.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py
1185
egs/audioset/AT/zipformer/train.py
Normal file
File diff suppressed because it is too large
1
egs/audioset/AT/zipformer/zipformer.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py