Merge branch 'master' into einichi

Machiko Bailey 2025-01-07 14:39:57 +09:00 committed by GitHub
commit 5c142d4c60
297 changed files with 30911 additions and 1750 deletions

.github/scripts/baker_zh/TTS/run-matcha.sh vendored Executable file

@@ -0,0 +1,167 @@
#!/usr/bin/env bash
set -ex
apt-get update
apt-get install -y sox
python3 -m pip install numba conformer==0.3.2 diffusers librosa
python3 -m pip install jieba
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/baker_zh/TTS
sed -i.bak s/600/8/g ./prepare.sh
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
sed -i.bak s/500/5/g ./prepare.sh
git diff
function prepare_data() {
# We have created a subset of the data for testing
#
mkdir -p download
pushd download
wget -q https://huggingface.co/csukuangfj/tmp-files/resolve/main/BZNSYP-samples.tar.bz2
tar xvf BZNSYP-samples.tar.bz2
mv BZNSYP-samples BZNSYP
rm BZNSYP-samples.tar.bz2
popd
./prepare.sh
tree .
}
function train() {
pushd ./matcha
sed -i.bak s/1500/3/g ./train.py
git diff .
popd
./matcha/train.py \
--exp-dir matcha/exp \
--num-epochs 1 \
--save-every-n 1 \
--num-buckets 2 \
--tokens data/tokens.txt \
--max-duration 20
ls -lh matcha/exp
}
function infer() {
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
./matcha/infer.py \
--num-buckets 2 \
--epoch 1 \
--exp-dir ./matcha/exp \
--tokens data/tokens.txt \
--cmvn ./data/fbank/cmvn.json \
--vocoder ./generator_v2 \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav ./generated.wav
ls -lh *.wav
soxi ./generated.wav
rm -v ./generated.wav
rm -v generator_v2
}
function export_onnx() {
pushd matcha/exp
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/epoch-2000.pt
popd
pushd data/fbank
rm -v *.json
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/cmvn.json
popd
./matcha/export_onnx.py \
--exp-dir ./matcha/exp \
--epoch 2000 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json
ls -lh *.onnx
if false; then
# The CI machine does not have enough memory to run it
#
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3
python3 ./matcha/export_onnx_hifigan.py
else
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx
fi
ls -lh *.onnx
python3 ./matcha/generate_lexicon.py
for v in v1 v2 v3; do
python3 ./matcha/onnx_pretrained.py \
--acoustic-model ./model-steps-6.onnx \
--vocoder ./hifigan_$v.onnx \
--tokens ./data/tokens.txt \
--lexicon ./lexicon.txt \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav /icefall/generated-matcha-tts-steps-6-$v.wav
done
ls -lh /icefall/*.wav
soxi /icefall/generated-matcha-tts-steps-6-*.wav
cp ./model-steps-*.onnx /icefall
d=matcha-icefall-zh-baker
mkdir $d
cp -v data/tokens.txt $d
cp -v lexicon.txt $d
cp model-steps-3.onnx $d
pushd $d
curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2
tar xvf dict.tar.bz2
rm dict.tar.bz2
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
cat >README.md <<EOF
# Introduction
This model is trained using the dataset from
https://en.data-baker.com/datasets/freeDatasets/
The dataset contains 10000 Chinese sentences of a native Chinese female speaker,
which is about 12 hours.
**Note**: The dataset is for non-commercial use only.
You can find the training code at
https://github.com/k2-fsa/icefall/tree/master/egs/baker_zh/TTS
EOF
ls -lh
popd
tar cvjf $d.tar.bz2 $d
mv $d.tar.bz2 /icefall
mv $d /icefall
}
prepare_data
train
infer
export_onnx
rm -rfv generator_v* matcha/exp
git checkout .

@@ -31,12 +31,15 @@ LABEL github_repo="https://github.com/k2-fsa/icefall"
 # Install dependencies
 RUN pip install --no-cache-dir \
-torch==${TORCH_VERSION} torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/cpu/torch_stable.html \
+torch==${TORCH_VERSION}+cpu -f https://download.pytorch.org/whl/torch \
+torchaudio==${TORCHAUDIO_VERSION}+cpu -f https://download.pytorch.org/whl/torchaudio \
 k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \
 \
 git+https://github.com/lhotse-speech/lhotse \
 kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \
+conformer==0.3.2 \
 cython \
+diffusers \
 dill \
 espnet_tts_frontend \
 graphviz \
@@ -45,10 +48,11 @@ RUN pip install --no-cache-dir \
 kaldialign \
 kaldifst \
 kaldilm \
-matplotlib \
+librosa \
+"matplotlib<=3.9.4" \
 multi_quantization \
 numba \
-numpy \
+"numpy<2.0" \
 onnxoptimizer \
 onnxsim \
 onnx \

@@ -2,9 +2,19 @@
 # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+import argparse
 import json
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--min-torch-version",
+        help="Minimum torch version",
+    )
+    return parser.parse_args()
+
+
 def version_gt(a, b):
     a_major, a_minor = list(map(int, a.split(".")))[:2]
     b_major, b_minor = list(map(int, b.split(".")))[:2]
@@ -42,22 +52,34 @@ def get_torchaudio_version(torch_version):
     return torch_version
-def get_matrix():
-    k2_version = "1.24.4.dev20240223"
-    kaldifeat_version = "1.25.4.dev20240223"
-    version = "20240725"
+def get_matrix(min_torch_version):
+    k2_version = "1.24.4.dev20241029"
+    kaldifeat_version = "1.25.5.dev20241029"
+    version = "20241218"
+
+    # torchaudio 2.5.0 does not support python 3.13
     python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
     torch_version = []
     torch_version += ["1.13.0", "1.13.1"]
     torch_version += ["2.0.0", "2.0.1"]
-    torch_version += ["2.1.0", "2.1.1", "2.1.2"]
-    torch_version += ["2.2.0", "2.2.1", "2.2.2"]
+    # torch_version += ["2.1.0", "2.1.1", "2.1.2"]
+    # torch_version += ["2.2.0", "2.2.1", "2.2.2"]
+    # Test only torch >= 2.3.0
     torch_version += ["2.3.0", "2.3.1"]
     torch_version += ["2.4.0"]
+    torch_version += ["2.4.1"]
+    torch_version += ["2.5.0"]
+    torch_version += ["2.5.1"]
     matrix = []
     for p in python_version:
         for t in torch_version:
+            if min_torch_version and version_gt(min_torch_version, t):
+                continue
             # torchaudio <= 1.13.x supports only python <= 3.10
             if version_gt(p, "3.10") and not version_gt(t, "2.0"):
@@ -67,21 +89,20 @@ def get_matrix():
             if version_gt(p, "3.11") and not version_gt(t, "2.1"):
                 continue
-            k2_version_2 = k2_version
-            kaldifeat_version_2 = kaldifeat_version
-            if t == "2.2.2":
-                k2_version_2 = "1.24.4.dev20240328"
-                kaldifeat_version_2 = "1.25.4.dev20240329"
-            elif t == "2.3.0":
-                k2_version_2 = "1.24.4.dev20240425"
-                kaldifeat_version_2 = "1.25.4.dev20240425"
-            elif t == "2.3.1":
-                k2_version_2 = "1.24.4.dev20240606"
-                kaldifeat_version_2 = "1.25.4.dev20240606"
-            elif t == "2.4.0":
-                k2_version_2 = "1.24.4.dev20240725"
-                kaldifeat_version_2 = "1.25.4.dev20240725"
+            if version_gt(p, "3.12") and not version_gt(t, "2.4"):
+                continue
+            if version_gt(t, "2.4") and version_gt("3.10", p):
+                # torch>=2.5 requires python 3.10
+                continue
+            if t == "2.5.1":
+                k2_version_2 = "1.24.4.dev20241122"
+                kaldifeat_version_2 = "1.25.5.dev20241126"
+            else:
+                k2_version_2 = k2_version
+                kaldifeat_version_2 = kaldifeat_version
             matrix.append(
                 {
@@ -97,7 +118,8 @@ def get_matrix():
 def main():
-    matrix = get_matrix()
+    args = get_args()
+    matrix = get_matrix(min_torch_version=args.min_torch_version)
     print(json.dumps({"include": matrix}))
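As an aside, not part of the diff itself: the JSON this script prints is what the workflows later in this commit read back through GitHub Actions' fromJson. Below is a minimal sketch, assuming only the {"include": [...]} shape shown above; the fields inside each matrix entry are defined elsewhere in the repository and are not assumed here.

```python
# Sketch only: run the matrix generator and inspect its output.
# The script path and the --min-torch-version flag come from this diff;
# everything about the per-entry contents is an assumption.
import json
import subprocess

out = subprocess.check_output(
    [
        "python3",
        ".github/scripts/docker/generate_build_matrix.py",
        "--min-torch-version",
        "2.3",
    ]
)
matrix = json.loads(out)["include"]  # the script prints {"include": [...]}
print(f"{len(matrix)} python/torch combinations")
for entry in matrix[:3]:
    print(entry)
```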

.github/scripts/ljspeech/TTS/run-matcha.sh vendored Executable file

@@ -0,0 +1,157 @@
#!/usr/bin/env bash
set -ex
apt-get update
apt-get install -y sox
python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
python3 -m pip install espnet_tts_frontend
python3 -m pip install numba conformer==0.3.2 diffusers librosa
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/ljspeech/TTS
sed -i.bak s/600/8/g ./prepare.sh
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
sed -i.bak s/500/5/g ./prepare.sh
git diff
function prepare_data() {
# We have created a subset of the data for testing
#
mkdir -p download
pushd download
wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
tar xvf LJSpeech-1.1.tar.bz2
popd
./prepare.sh
tree .
}
function train() {
pushd ./matcha
sed -i.bak s/1500/3/g ./train.py
git diff .
popd
./matcha/train.py \
--exp-dir matcha/exp \
--num-epochs 1 \
--save-every-n 1 \
--num-buckets 2 \
--tokens data/tokens.txt \
--max-duration 20
ls -lh matcha/exp
}
function infer() {
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
./matcha/infer.py \
--num-buckets 2 \
--epoch 1 \
--exp-dir ./matcha/exp \
--tokens data/tokens.txt \
--vocoder ./generator_v1 \
--input-text "how are you doing?" \
--output-wav ./generated.wav
ls -lh *.wav
soxi ./generated.wav
rm -v ./generated.wav
rm -v generator_v1
}
function export_onnx() {
pushd matcha/exp
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/epoch-4000.pt
popd
pushd data/fbank
rm -fv *.json
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/data/cmvn.json
popd
./matcha/export_onnx.py \
--exp-dir ./matcha/exp \
--epoch 4000 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json
ls -lh *.onnx
if false; then
# The CI machine does not have enough memory to run it
#
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3
python3 ./matcha/export_onnx_hifigan.py
else
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx
fi
ls -lh *.onnx
for v in v1 v2 v3; do
python3 ./matcha/onnx_pretrained.py \
--acoustic-model ./model-steps-6.onnx \
--vocoder ./hifigan_$v.onnx \
--tokens ./data/tokens.txt \
--input-text "how are you doing?" \
--output-wav /icefall/generated-matcha-tts-steps-6-$v.wav
done
ls -lh /icefall/*.wav
soxi /icefall/generated-matcha-tts-steps-6-*.wav
cp ./model-steps-*.onnx /icefall
d=matcha-icefall-en_US-ljspeech
mkdir $d
cp -v data/tokens.txt $d
cp model-steps-3.onnx $d
pushd $d
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
tar xf espeak-ng-data.tar.bz2
rm espeak-ng-data.tar.bz2
cat >README.md <<EOF
# Introduction
This model is trained using the dataset from
https://keithito.com/LJ-Speech-Dataset/
The dataset contains only 1 female speaker.
You can find the training code at
https://github.com/k2-fsa/icefall/tree/master/egs/ljspeech/TTS#matcha
EOF
ls -lh
popd
tar cvjf $d.tar.bz2 $d
mv $d.tar.bz2 /icefall
mv $d /icefall
}
prepare_data
train
infer
export_onnx
rm -rfv generator_v* matcha/exp
git checkout .

@@ -22,7 +22,7 @@ git diff
 function prepare_data() {
 # We have created a subset of the data for testing
 #
-mkdir download
+mkdir -p download
 pushd download
 wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
 tar xvf LJSpeech-1.1.tar.bz2

@@ -16,6 +16,48 @@ log "pwd: $PWD"
 cd egs/multi_zh-hans/ASR
+repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+pushd $repo
+cd exp
+git lfs pull --include pretrained.pt
+ln -s pretrained.pt epoch-99.pt
+cd ../data/lang_bpe_2000
+ls -lh
+git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model
+git lfs pull --include "*.model"
+ls -lh
+popd
+log "--------------------------------------------"
+log "Export non-streaming ONNX transducer models "
+log "--------------------------------------------"
+./zipformer/export-onnx.py \
+--tokens $repo/data/lang_bpe_2000/tokens.txt \
+--use-averaged-model 0 \
+--epoch 99 \
+--avg 1 \
+--exp-dir $repo/exp \
+--causal False
+ls -lh $repo/exp
+./zipformer/onnx_pretrained.py \
+--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+--tokens $repo/data/lang_bpe_2000/tokens.txt \
+$repo/test_wavs/DEV_T0000000000.wav \
+$repo/test_wavs/DEV_T0000000001.wav \
+$repo/test_wavs/DEV_T0000000002.wav \
+$repo/test_wavs/TEST_MEETING_T0000000113.wav \
+$repo/test_wavs/TEST_MEETING_T0000000219.wav \
+$repo/test_wavs/TEST_MEETING_T0000000351.wav
+rm -rf $repo
 repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05
 log "Downloading pre-trained model from $repo_url"
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url

@@ -19,7 +19,7 @@ repo=$(basename $repo_url)
 echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
 echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_NAME}" == x"workflow_dispatch" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
 mkdir -p pruned_transducer_stateless2/exp
 ln -s $PWD/$repo/exp/pretrained-iter-3488000-avg-20.pt pruned_transducer_stateless2/exp/epoch-999.pt
 ln -s $PWD/$repo/data/lang_bpe_500 data/
@@ -29,8 +29,16 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
 ls -lh data/fbank
 ls -lh pruned_transducer_stateless2/exp
-ln -s data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz
-ln -s data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz
+pushd data/fbank
+curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/cuts_DEV.jsonl.gz
+curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/cuts_TEST.jsonl.gz
+curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/feats_DEV.lca
+curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/feats_TEST.lca
+ln -sf cuts_DEV.jsonl.gz gigaspeech_cuts_DEV.jsonl.gz
+ln -sf cuts_TEST.jsonl.gz gigaspeech_cuts_TEST.jsonl.gz
+popd
 log "Decoding dev and test"

@@ -129,20 +129,34 @@ done
 echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
 echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_NAME}" == x"workflow_dispatch" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
 mkdir -p zipformer/exp
 ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-30.pt
+mkdir -p data
 ln -s $PWD/$repo/data/lang_bpe_500 data/
 ls -lh data
 ls -lh zipformer/exp
+mkdir -p data/fbank
+pushd data/fbank
+curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/cuts_DEV.jsonl.gz
+curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/cuts_TEST.jsonl.gz
+curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/feats_DEV.lca
+curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/feats_TEST.lca
+ln -sf cuts_DEV.jsonl.gz gigaspeech_cuts_DEV.jsonl.gz
+ln -sf cuts_TEST.jsonl.gz gigaspeech_cuts_TEST.jsonl.gz
+popd
 log "Decoding test-clean and test-other"
 # use a small value for decoding with CPU
 max_duration=100
-for method in greedy_search fast_beam_search modified_beam_search; do
+for method in greedy_search; do
 log "Decoding with $method"
 ./zipformer/decode.py \

@@ -162,7 +162,7 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then
 --ngram-lm-scale -0.16
 fi
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_NAME}" == x"workflow_dispatch" ]]; then
 mkdir -p lstm_transducer_stateless2/exp
 ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
 ln -s $PWD/$repo/data/lang_bpe_500 data/
@@ -175,7 +175,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
 # use a small value for decoding with CPU
 max_duration=100
-for method in greedy_search fast_beam_search modified_beam_search; do
+for method in greedy_search fast_beam_search; do
 log "Decoding with $method"
 ./lstm_transducer_stateless2/decode.py \

@@ -25,6 +25,7 @@ popd
 log "Export via torch.jit.script()"
 ./zipformer/export.py \
+--use-averaged-model 0 \
 --exp-dir $repo/exp \
 --tokens $repo/data/lang_bpe_500/tokens.txt \
 --epoch 99 \

@@ -83,7 +83,7 @@ jobs:
 ls -lh ./model-onnx/*
 - name: Upload model to huggingface
-if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
 env:
 HF_TOKEN: ${{ secrets.HF_TOKEN }}
 uses: nick-fields/retry@v3
@@ -116,7 +116,7 @@ jobs:
 rm -rf huggingface
 - name: Prepare for release
-if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
 shell: bash
 run: |
 d=sherpa-onnx-zipformer-audio-tagging-2024-04-09
@@ -125,7 +125,7 @@ jobs:
 ls -lh
 - name: Release exported onnx models
-if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
 uses: svenstaro/upload-release-action@v2
 with:
 file_glob: true

.github/workflows/baker_zh.yml vendored Normal file

@@ -0,0 +1,152 @@
name: baker_zh
on:
push:
branches:
- master
- baker-matcha-2
pull_request:
branches:
- master
workflow_dispatch:
concurrency:
group: baker-zh-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
echo "::set-output name=matrix::${MATRIX}"
baker_zh:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Free space
shell: bash
run: |
ls -lh
df -h
rm -rf /opt/hostedtoolcache
df -h
echo "pwd: $PWD"
echo "github.workspace ${{ github.workspace }}"
- name: Run tests
uses: addnab/docker-run-action@v3
with:
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
run: |
export PYTHONPATH=/icefall:$PYTHONPATH
cd /icefall
pip install onnx==1.17.0
pip list
git config --global --add safe.directory /icefall
.github/scripts/baker_zh/TTS/run-matcha.sh
- name: display files
shell: bash
run: |
ls -lh
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
path: ./*.wav
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-2
path: ./model-steps-2.onnx
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-3
path: ./model-steps-3.onnx
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-4
path: ./model-steps-4.onnx
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-5
path: ./model-steps-5.onnx
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-6
path: ./model-steps-6.onnx
- name: Upload models to huggingface
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
d=matcha-icefall-zh-baker
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/$d hf
cp -av $d/* hf/
pushd hf
git add .
git config --global user.name "csukuangfj"
git config --global user.email "csukuangfj@gmail.com"
git config --global lfs.allowincompletepush true
git commit -m "upload model" && git push https://csukuangfj:${HF_TOKEN}@huggingface.co/csukuangfj/$d main || true
popd
- name: Release exported onnx models
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
overwrite: true
file: matcha-icefall-*.tar.bz2
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: tts-models

@@ -26,6 +26,8 @@ on:
 pull_request:
 types: [labeled]
+workflow_dispatch:
+
 concurrency:
 group: build_doc-${{ github.ref }}
 cancel-in-progress: true

@@ -16,7 +16,9 @@ jobs:
 fail-fast: false
 matrix:
 os: [ubuntu-latest]
-image: ["torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+image: ["torch2.4.1-cuda12.4", "torch2.4.1-cuda12.1", "torch2.4.1-cuda11.8", "torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
 steps:
 # refer to https://github.com/actions/checkout

@@ -30,8 +30,8 @@ jobs:
 id: set-matrix
 run: |
 # outputting for debugging purposes
-python ./.github/scripts/docker/generate_build_matrix.py
-MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
+MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
 echo "::set-output name=matrix::${MATRIX}"
 ljspeech:
@@ -70,6 +70,11 @@ jobs:
 cd /icefall
 git config --global --add safe.directory /icefall
+pip install "matplotlib<=3.9.4"
+pip list
+.github/scripts/ljspeech/TTS/run-matcha.sh
 .github/scripts/ljspeech/TTS/run.sh
 - name: display files
@@ -78,19 +83,13 @@ jobs:
 ls -lh
 - uses: actions/upload-artifact@v4
-if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
 with:
 name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
 path: ./*.wav
-- uses: actions/upload-artifact@v4
-if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
-with:
-name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
-path: ./*.wav
 - name: Release exported onnx models
-if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
 uses: svenstaro/upload-release-action@v2
 with:
 file_glob: true
@@ -100,3 +99,68 @@ jobs:
 repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
 tag: tts-models
+- uses: actions/upload-artifact@v4
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+with:
+name: step-2
+path: ./model-steps-2.onnx
+- uses: actions/upload-artifact@v4
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+with:
+name: step-3
+path: ./model-steps-3.onnx
+- uses: actions/upload-artifact@v4
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+with:
+name: step-4
+path: ./model-steps-4.onnx
+- uses: actions/upload-artifact@v4
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+with:
+name: step-5
+path: ./model-steps-5.onnx
+- uses: actions/upload-artifact@v4
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+with:
+name: step-6
+path: ./model-steps-6.onnx
+- name: Upload models to huggingface
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+shell: bash
+env:
+HF_TOKEN: ${{ secrets.HF_TOKEN }}
+run: |
+d=matcha-icefall-en_US-ljspeech
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/$d hf
+cp -av $d/* hf/
+pushd hf
+git lfs track "cmn_dict"
+git lfs track "ru_dict"
+git add .
+git config --global user.name "csukuangfj"
+git config --global user.email "csukuangfj@gmail.com"
+git config --global lfs.allowincompletepush true
+git commit -m "upload model" && git push https://csukuangfj:${HF_TOKEN}@huggingface.co/csukuangfj/$d main || true
+popd
+- name: Release exported onnx models
+if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
+uses: svenstaro/upload-release-action@v2
+with:
+file_glob: true
+overwrite: true
+file: matcha-icefall-*.tar.bz2
+repo_name: k2-fsa/sherpa-onnx
+repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+tag: tts-models

@@ -33,13 +33,15 @@ on:
 # nightly build at 15:50 UTC time every day
 - cron: "50 15 * * *"
+workflow_dispatch:
+
 concurrency:
 group: run_gigaspeech_2022_05_13-${{ github.ref }}
 cancel-in-progress: true
 jobs:
 run_gigaspeech_2022_05_13:
-if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+if: github.event_name == 'workflow_dispatch' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
 runs-on: ${{ matrix.os }}
 strategy:
 matrix:
@@ -104,7 +106,7 @@ jobs:
 .github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh
 - name: Display decoding results for gigaspeech pruned_transducer_stateless2
-if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || github.event.label.name == 'run-decode'
 shell: bash
 run: |
 cd egs/gigaspeech/ASR/
@@ -119,8 +121,8 @@ jobs:
 find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
 - name: Upload decoding results for gigaspeech pruned_transducer_stateless2
-uses: actions/upload-artifact@v2
+uses: actions/upload-artifact@v4
-if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || github.event.label.name == 'run-decode'
 with:
 name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
 path: egs/gigaspeech/ASR/pruned_transducer_stateless2/exp/

@@ -42,7 +42,7 @@ concurrency:
 jobs:
 run_gigaspeech_2023_10_17_zipformer:
-if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
 runs-on: ${{ matrix.os }}
 strategy:
 matrix:
@@ -90,10 +90,6 @@ jobs:
 GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
 HF_TOKEN: ${{ secrets.HF_TOKEN }}
 run: |
-mkdir -p egs/gigaspeech/ASR/data
-ln -sfv ~/tmp/fbank-libri egs/gigaspeech/ASR/data/fbank
-ls -lh egs/gigaspeech/ASR/data/*
 sudo apt-get -qq install git-lfs tree
 export PYTHONPATH=$PWD:$PYTHONPATH
 export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
@@ -112,7 +108,7 @@ jobs:
 tag: asr-models
 - name: Display decoding results for gigaspeech zipformer
-if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' || github.event_name == 'workflow_dispatch'
 shell: bash
 run: |
 cd egs/gigaspeech/ASR/
@@ -124,17 +120,17 @@ jobs:
 find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
 find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-echo "===fast_beam_search==="
+# echo "===fast_beam_search==="
-find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+# find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
-find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+# find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+#
-echo "===modified beam search==="
+# echo "===modified beam search==="
-find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+# find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
-find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+# find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
 - name: Upload decoding results for gigaspeech zipformer
-uses: actions/upload-artifact@v2
+uses: actions/upload-artifact@v4
-if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' || github.event_name == 'workflow_dispatch'
 with:
 name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
 path: egs/gigaspeech/ASR/zipformer/exp/

@@ -16,13 +16,15 @@ on:
 # nightly build at 15:50 UTC time every day
 - cron: "50 15 * * *"
+workflow_dispatch:
+
 concurrency:
 group: run_librispeech_lstm_transducer_stateless2_2022_09_03-${{ github.ref }}
 cancel-in-progress: true
 jobs:
 run_librispeech_lstm_transducer_stateless2_2022_09_03:
-if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule'
+if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
 runs-on: ${{ matrix.os }}
 strategy:
 matrix:
@@ -114,7 +116,7 @@ jobs:
 .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
 - name: Display decoding results for lstm_transducer_stateless2
-if: github.event_name == 'schedule'
+if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
 shell: bash
 run: |
 cd egs/librispeech/ASR
@@ -128,9 +130,9 @@ jobs:
 find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
 find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-echo "===modified beam search==="
+# echo "===modified beam search==="
-find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+# find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
-find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+# find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
 - name: Display decoding results for lstm_transducer_stateless2
 if: github.event.label.name == 'shallow-fusion'
@@ -156,8 +158,8 @@ jobs:
 find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
 - name: Upload decoding results for lstm_transducer_stateless2
-uses: actions/upload-artifact@v2
+uses: actions/upload-artifact@v4
-if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR'
+if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR' || github.event_name == 'workflow_dispatch'
 with:
 name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-lstm_transducer_stateless2-2022-09-03
 path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/

@@ -23,6 +23,8 @@ on:
 pull_request:
 types: [labeled]
+workflow_dispatch:
+
 concurrency:
 group: run_multi-corpora_zipformer-${{ github.ref }}
 cancel-in-progress: true

@@ -16,6 +16,8 @@ on:
 # nightly build at 15:50 UTC time every day
 - cron: "50 15 * * *"
+workflow_dispatch:
+
 concurrency:
 group: run_ptb_rnn_lm_training-${{ github.ref }}
 cancel-in-progress: true
@@ -64,7 +66,7 @@ jobs:
 ./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2
 - name: Upload pretrained models
-uses: actions/upload-artifact@v2
+uses: actions/upload-artifact@v4
 if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
 with:
 name: python-${{ matrix.python-version }}-ubuntu-rnn-lm-ptb

@@ -23,6 +23,8 @@ on:
 pull_request:
 types: [labeled]
+workflow_dispatch:
+
 concurrency:
 group: run-swbd-conformer_ctc-${{ github.ref }}
 cancel-in-progress: true

@@ -23,6 +23,8 @@ on:
 pull_request:
 types: [labeled]
+workflow_dispatch:
+
 concurrency:
 group: run_wenetspeech_pruned_transducer_stateless2-${{ github.ref }}
 cancel-in-progress: true

@@ -24,6 +24,8 @@ on:
 branches:
 - master
+workflow_dispatch:
+
 concurrency:
 group: style_check-${{ github.ref }}
 cancel-in-progress: true
@@ -34,7 +36,7 @@ jobs:
 strategy:
 matrix:
 os: [ubuntu-latest]
-python-version: [3.8]
+python-version: [3.10.15]
 fail-fast: false
 steps:

@@ -16,6 +16,8 @@ on:
 # nightly build at 15:50 UTC time every day
 - cron: "50 15 * * *"
+workflow_dispatch:
+
 concurrency:
 group: test_ncnn_export-${{ github.ref }}
 cancel-in-progress: true

@@ -16,6 +16,8 @@ on:
 # nightly build at 15:50 UTC time every day
 - cron: "50 15 * * *"
+workflow_dispatch:
+
 concurrency:
 group: test_onnx_export-${{ github.ref }}
 cancel-in-progress: true

@@ -105,7 +105,7 @@ jobs:
 cd ../zipformer
 pytest -v -s
-- uses: actions/upload-artifact@v2
+- uses: actions/upload-artifact@v4
 with:
 path: egs/librispeech/ASR/zipformer/swoosh.pdf
-name: swoosh.pdf
+name: swoosh-${{ matrix.python-version }}-${{ matrix.torch-version }}

@@ -61,5 +61,6 @@ jobs:
 python3 -m torch.utils.collect_env
 python3 -m k2.version
+pip list
 .github/scripts/yesno/ASR/run.sh

@@ -42,7 +42,6 @@ for more details.
 - [LibriSpeech][librispeech]
 - [Libriheavy][libriheavy]
 - [Multi-Dialect Broadcast News Arabic Speech Recognition][mgb2]
-- [PeopleSpeech][peoplespeech]
 - [SPGISpeech][spgispeech]
 - [Switchboard][swbd]
 - [TIMIT][timit]
@@ -334,6 +333,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt
 - [LJSpeech][ljspeech]
 - [VCTK][vctk]
+- [LibriTTS][libritts_tts]
 ### Supported Models
@@ -373,12 +373,13 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
 [commonvoice]: egs/commonvoice/ASR
 [csj]: egs/csj/ASR
 [libricss]: egs/libricss/SURT
+[libritts_asr]: egs/libritts/ASR
 [libriheavy]: egs/libriheavy/ASR
 [mgb2]: egs/mgb2/ASR
-[peoplespeech]: egs/peoplespeech/ASR
 [spgispeech]: egs/spgispeech/ASR
 [voxpopuli]: egs/voxpopuli/ASR
 [xbmu-amdo31]: egs/xbmu-amdo31/ASR
 [vctk]: egs/vctk/TTS
 [ljspeech]: egs/ljspeech/TTS
+[libritts_tts]: egs/libritts/TTS

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.4.1-cuda11.8-cudnn9-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240905+cuda11.8.torch2.4.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240905+cuda11.8.torch2.4.1"
ARG TORCHAUDIO_VERSION="2.4.1+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.4.1-cuda12.1-cudnn9-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240905+cuda12.1.torch2.4.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240905+cuda12.1.torch2.4.1"
ARG TORCHAUDIO_VERSION="2.4.1+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -0,0 +1,73 @@
FROM pytorch/pytorch:2.4.1-cuda12.4-cudnn9-devel
# python 3.10
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240905+cuda12.4.torch2.4.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240905+cuda12.4.torch2.4.1"
ARG TORCHAUDIO_VERSION="2.4.1+cu124"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
onnxoptimizer \
onnxsim \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

@@ -34,6 +34,12 @@ which will give you something like below:
 .. code-block:: bash
+"torch2.4.1-cuda12.4"
+"torch2.4.1-cuda12.1"
+"torch2.4.1-cuda11.8"
+"torch2.4.0-cuda12.4"
+"torch2.4.0-cuda12.1"
+"torch2.4.0-cuda11.8"
 "torch2.3.1-cuda12.1"
 "torch2.3.1-cuda11.8"
 "torch2.2.2-cuda12.1"

@@ -87,7 +87,7 @@ fi
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 log "Stage 3: Prepare musan manifest"
 # We assume that you have downloaded the musan corpus
-# to data/musan
+# to $dl_dir/musan
 if [ ! -f data/manifests/.musan_manifests.done ]; then
 log "It may take 6 minutes"
 mkdir -p data/manifests

@@ -58,7 +58,7 @@ if [ $stage -le 4 ]; then
 # for train, we use smaller context and larger batches to speed-up processing
 for JOB in $(seq $nj); do
 gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
-$EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \
+$EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \
 --bss-iterations 10 \
 --context-duration 5.0 \
 --use-garbage-class \
@@ -77,7 +77,7 @@ if [ $stage -le 5 ]; then
 for part in eval test; do
 for JOB in $(seq $nj); do
 gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
-$EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \
+$EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \
 $EXP_DIR/enhanced \
 --bss-iterations 10 \
 --context-duration 15.0 \

@@ -65,7 +65,7 @@ fi
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 log "Stage 2: Prepare musan manifest"
 # We assume that you have downloaded the musan corpus
-# to data/musan
+# to $dl_dir/musan
 mkdir -p data/manifests
 lhotse prepare musan $dl_dir/musan data/manifests
 fi

@@ -82,7 +82,7 @@ class AlimeetingAsrDataModule:
         group.add_argument(
             "--manifest-dir",
             type=Path,
-            default=Path("data/manifests"),
+            default=Path("data/fbank"),
             help="Path to directory with train/valid/test cuts.",
         )
         group.add_argument(
@@ -327,9 +327,11 @@ class AlimeetingAsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
+            input_strategy=(
+                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                if self.args.on_the_fly_feats
+                else PrecomputedFeatures()
+            ),
             return_cuts=True,
         )
         sampler = DynamicBucketingSampler(

@@ -58,7 +58,7 @@ if [ $stage -le 4 ]; then
 # for train, we use smaller context and larger batches to speed-up processing
 for JOB in $(seq $nj); do
 gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
-$EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \
+$EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \
 --bss-iterations 10 \
 --context-duration 5.0 \
 --use-garbage-class \
@@ -77,7 +77,7 @@ if [ $stage -le 5 ]; then
 for part in dev test; do
 for JOB in $(seq $nj); do
 gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
-$EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \
+$EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \
 $EXP_DIR/enhanced \
 --bss-iterations 10 \
 --context-duration 15.0 \

@@ -35,16 +35,40 @@ python zipformer/train.py \
 --master-port 13455
 ```
+We recommend that you train the model with a weighted sampler, as the model converges
+faster with better performance:
+
+| Model | mAP |
+| ------ | ------- |
+| Zipformer-AT, train with weighted sampler | 46.6 |
+
 The evaluation command is:
 ```bash
-python zipformer/evaluate.py \
---epoch 32 \
---avg 8 \
---exp-dir zipformer/exp_at_as_full \
---max-duration 500
+export CUDA_VISIBLE_DEVICES="4,5,6,7"
+subset=full
+weighted_sampler=1
+bucket_sampler=0
+lr_epochs=15
+
+python zipformer/train.py \
+--world-size 4 \
+--audioset-subset $subset \
+--num-epochs 120 \
+--start-epoch 1 \
+--use-fp16 1 \
+--num-events 527 \
+--lr-epochs $lr_epochs \
+--exp-dir zipformer/exp_AS_${subset}_weighted_sampler${weighted_sampler} \
+--weighted-sampler $weighted_sampler \
+--bucketing-sampler $bucket_sampler \
+--max-duration 1000 \
+--enable-musan True \
+--master-port 13452
 ```
+The command for evaluation is the same. The pre-trained model can be downloaded from https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-M-weighted-sampler
 #### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M
@@ -92,4 +116,4 @@ python zipformer/evaluate.py \
 --encoder-unmasked-dim 192,192,192,192,192,192 \
 --exp-dir zipformer/exp_small_at_as_full \
 --max-duration 500
 ```

View File

@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes the sampling weights of the AudioSet cuts for weighted-sampling
training. Each cut is weighted by the inverse frequency of its labels, and the
weights are written to the file given by --output, one "cut_id weight" pair per line.
"""
import argparse
import lhotse
from lhotse import load_manifest
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz"
)
parser.add_argument(
"--output",
type=str,
required=True,
)
return parser
def main():
# Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py
parser = get_parser()
args = parser.parse_args()
cuts = load_manifest(args.input_manifest)
print(f"A total of {len(cuts)} cuts.")
label_count = [0] * 527 # a total of 527 classes
for c in cuts:
audio_event = c.supervisions[0].audio_event
labels = list(map(int, audio_event.split(";")))
for label in labels:
label_count[label] += 1
with open(args.output, "w") as f:
for c in cuts:
audio_event = c.supervisions[0].audio_event
labels = list(map(int, audio_event.split(";")))
weight = 0
for label in labels:
weight += 1000 / (label_count[label] + 0.01)
f.write(f"{c.id} {weight}\n")
if __name__ == "__main__":
main()
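
The weighting scheme above is plain inverse label frequency: a cut's weight is the sum of `1000 / (count + 0.01)` over its labels, so cuts that contain rare events are drawn more often, and the script then writes one `cut_id weight` line per cut. A small self-contained sketch of the same arithmetic (the label counts below are invented for illustration):

```python
# Toy illustration of the weight formula used in compute_weight.py; counts are made up.
label_count = {0: 10, 1: 1000}  # label 0 is rare, label 1 is common


def cut_weight(labels):
    # Same formula as above: sum of 1000 / (count + 0.01) over the cut's labels.
    return sum(1000 / (label_count[label] + 0.01) for label in labels)


print(cut_weight([0]))     # ~99.9  -> a cut with only the rare label
print(cut_weight([1]))     # ~1.0   -> a cut with only the common label
print(cut_weight([0, 1]))  # ~100.9 -> rare labels dominate the weight
```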

View File

@@ -10,6 +10,7 @@ stage=-1
 stop_stage=4
 
 dl_dir=$PWD/download
+fbank_dir=data/fbank
 
 # we assume that you have downloaded the AudioSet and placed
 # it under $dl_dir/audioset, the folder structure should look like
@@ -49,7 +50,6 @@ fi
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
-  fbank_dir=data/fbank
   if [ ! -e $fbank_dir/.balanced.done ]; then
     python local/generate_audioset_manifest.py \
       --dataset-dir $dl_dir/audioset \
@@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     touch data/fbank/.musan.done
   fi
 fi
+
+# The following stage is required for weighted-sampling training
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare for weighted-sampling training"
+  if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then
+    lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz
+  fi
+  python ./local/compute_weight.py \
+    --input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \
+    --output $fbank_dir/sampling_weights_full.txt
+fi
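
As a quick sanity check after stage 5 (a sketch, not part of the recipe; the paths are the defaults used above), one can verify that every cut in the combined manifest received a weight:

```python
# Sketch: check that sampling_weights_full.txt has one "<cut_id> <weight>" line per cut.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/cuts_audioset_full.jsonl.gz")

weights = {}
with open("data/fbank/sampling_weights_full.txt") as f:
    for line in f:
        cut_id, weight = line.split()
        weights[cut_id] = float(weight)

missing = [c.id for c in cuts if c.id not in weights]
print(f"{len(weights)} weights, {len(missing)} cuts without a weight")
```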

View File

@@ -31,6 +31,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
+    WeightedSimpleCutSampler,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     AudioSamples,
@@ -99,6 +100,20 @@ class AudioSetATDatamodule:
             help="Maximum pooled recordings duration (seconds) in a "
             "single batch. You can reduce it if it causes CUDA OOM.",
         )
+        group.add_argument(
+            "--weighted-sampler",
+            type=str2bool,
+            default=False,
+            help="When enabled, samples are drawn by their weights. "
+            "It cannot be used together with the bucketing sampler",
+        )
+        group.add_argument(
+            "--num-samples",
+            type=int,
+            default=200000,
+            help="The number of samples to be drawn in each epoch. Only used "
+            "for the weighted sampler",
+        )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
@@ -295,6 +310,9 @@ class AudioSetATDatamodule:
         )
 
         if self.args.bucketing_sampler:
+            assert (
+                not self.args.weighted_sampler
+            ), "weighted sampling is not supported by the bucketing sampler"
             logging.info("Using DynamicBucketingSampler.")
             train_sampler = DynamicBucketingSampler(
                 cuts_train,
@@ -304,13 +322,26 @@ class AudioSetATDatamodule:
                 drop_last=self.args.drop_last,
             )
         else:
-            logging.info("Using SimpleCutSampler.")
-            train_sampler = SimpleCutSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-                drop_last=self.args.drop_last,
-            )
+            if self.args.weighted_sampler:
+                # assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset"
+                logging.info("Using weighted SimpleCutSampler")
+                weights = self.audioset_sampling_weights()
+                train_sampler = WeightedSimpleCutSampler(
+                    cuts_train,
+                    weights,
+                    num_samples=self.args.num_samples,
+                    max_duration=self.args.max_duration,
+                    shuffle=False,  # shuffling is not supported
+                    drop_last=self.args.drop_last,
+                )
+            else:
+                logging.info("Using SimpleCutSampler.")
+                train_sampler = SimpleCutSampler(
+                    cuts_train,
+                    max_duration=self.args.max_duration,
+                    shuffle=self.args.shuffle,
+                    drop_last=self.args.drop_last,
+                )
 
         logging.info("About to create train dataloader")
         if sampler_state_dict is not None:
@@ -373,11 +404,9 @@ class AudioSetATDatamodule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = AudioTaggingDataset(
-            input_strategy=(
-                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-                if self.args.on_the_fly_feats
-                else eval(self.args.input_strategy)()
-            ),
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
@@ -397,21 +426,30 @@ class AudioSetATDatamodule:
     @lru_cache()
     def audioset_train_cuts(self) -> CutSet:
         logging.info("About to get the audioset training cuts.")
-        balanced_cuts = load_manifest_lazy(
-            self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
-        )
-        if self.args.audioset_subset == "full":
-            unbalanced_cuts = load_manifest_lazy(
-                self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
-            )
-            cuts = CutSet.mux(
-                balanced_cuts,
-                unbalanced_cuts,
-                weights=[20000, 2000000],
-                stop_early=True,
-            )
-        else:
-            cuts = balanced_cuts
+        if not self.args.weighted_sampler:
+            balanced_cuts = load_manifest_lazy(
+                self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
+            )
+            if self.args.audioset_subset == "full":
+                unbalanced_cuts = load_manifest_lazy(
+                    self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
+                )
+                cuts = CutSet.mux(
+                    balanced_cuts,
+                    unbalanced_cuts,
+                    weights=[20000, 2000000],
+                    stop_early=True,
+                )
+            else:
+                cuts = balanced_cuts
+        else:
+            # assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet"
+            cuts = load_manifest(
+                self.args.manifest_dir
+                / f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz"
+            )
+        logging.info(f"Get {len(cuts)} cuts in total.")
         return cuts
 
     @lru_cache()
@@ -420,3 +458,22 @@ class AudioSetATDatamodule:
         return load_manifest_lazy(
             self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
         )
+
+    @lru_cache()
+    def audioset_sampling_weights(self):
+        logging.info(
+            f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet"
+        )
+        weights = []
+        with open(
+            self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt",
+            "r",
+        ) as f:
+            while True:
+                line = f.readline()
+                if not line:
+                    break
+                weight = float(line.split()[1])
+                weights.append(weight)
+        logging.info(f"Get the sampling weight for {len(weights)} cuts")
+        return weights
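
For reference, a minimal sketch of the weighted-sampling path that the datamodule takes when `--weighted-sampler` is enabled (assuming a lhotse version that provides `WeightedSimpleCutSampler`, as imported above; the file name and the numbers are illustrative):

```python
# Minimal sketch of the weighted-sampling setup shown in the diff above.
from lhotse import load_manifest
from lhotse.dataset import WeightedSimpleCutSampler

cuts = load_manifest("data/fbank/cuts_audioset_full.jsonl.gz")

# One weight per cut, in the same order as the manifest (see compute_weight.py).
with open("data/fbank/sampling_weights_full.txt") as f:
    weights = [float(line.split()[1]) for line in f]

sampler = WeightedSimpleCutSampler(
    cuts,
    weights,
    num_samples=200000,  # cuts drawn per epoch
    max_duration=1000,   # seconds of audio per batch
    shuffle=False,       # shuffling is not supported together with weights
    drop_last=True,
)
```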

View File

@@ -789,12 +789,14 @@ def train_one_epoch(
             rank=0,
         )
 
+    num_samples = 0
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
 
         params.batch_idx_train += 1
         batch_size = batch["inputs"].size(0)
+        num_samples += batch_size
 
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
@@ -919,6 +921,12 @@
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
 
+        if num_samples > params.num_samples:
+            logging.info(
+                f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch"
+            )
+            break
+
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -1032,7 +1040,8 @@ def run(rank, world_size, args):
             return True
 
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    if not params.weighted_sampler:
+        train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint

6
egs/baker_zh/TTS/.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
path.sh
*.onnx
*.wav
generator_v1
generator_v2
generator_v3

146
egs/baker_zh/TTS/README.md Normal file
View File

@ -0,0 +1,146 @@
# Introduction
It is for the dataset from
https://en.data-baker.com/datasets/freeDatasets/
The dataset contains 10000 Chinese sentences of a native Chinese female speaker,
which is about 12 hours.
**Note**: The dataset is for non-commercial use only.
# matcha
[./matcha](./matcha) contains the code for training [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS)
Checkpoints and training logs can be found [here](https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27).
The pull-request for this recipe can be found at <https://github.com/k2-fsa/icefall/pull/1849>
The training command is given below:
```bash
python3 ./matcha/train.py \
--exp-dir ./matcha/exp-1/ \
--num-workers 4 \
--world-size 1 \
--num-epochs 2000 \
--max-duration 1200 \
--bucketing-sampler 1 \
--start-epoch 1
```
To inference, use:
```bash
# Download Hifigan vocoder. We use Hifigan v2 below. You can select from v1, v2, or v3
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
python3 ./matcha/infer.py \
--epoch 2000 \
--exp-dir ./matcha/exp-1 \
--vocoder ./generator_v2 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav ./generated.wav
```
```bash
soxi ./generated.wav
```
prints:
```
Input File : './generated.wav'
Channels : 1
Sample Rate : 22050
Precision : 16-bit
Duration : 00:00:17.31 = 381696 samples ~ 1298.29 CDDA sectors
File Size : 763k
Bit Rate : 353k
Sample Encoding: 16-bit Signed Integer PCM
```
https://github.com/user-attachments/assets/88d4e88f-ebc4-4f32-b216-16d46b966024
To export the checkpoint to onnx:
```bash
python3 ./matcha/export_onnx.py \
--exp-dir ./matcha/exp-1 \
--epoch 2000 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json
```
The above command generates the following files:
```
-rw-r--r-- 1 kuangfangjun root 72M Dec 27 18:53 model-steps-2.onnx
-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-3.onnx
-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-4.onnx
-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:55 model-steps-5.onnx
-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:57 model-steps-6.onnx
```
where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver.
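
If you want to double-check which step count an exported file was built with, the number is also stored in the ONNX metadata written by `export_onnx.py` (see below). A quick way to inspect it, assuming `onnxruntime` is installed, is:

```python
# Sketch: read the metadata that export_onnx.py stores in each exported model.
import onnxruntime as ort

session = ort.InferenceSession("model-steps-2.onnx", providers=["CPUExecutionProvider"])
meta = session.get_modelmeta().custom_metadata_map
print(meta["num_ode_steps"])  # -> "2" for model-steps-2.onnx
print(meta["sample_rate"])    # -> "22050"
```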
**HINT**: If you get the following error while running `export_onnx.py`:
```
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator
'aten::scaled_dot_product_attention' to ONNX opset version 14 is not supported.
```
please use `torch>=2.2.0`.
To export the Hifigan vocoder to onnx, please use:
```bash
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3
python3 ./matcha/export_onnx_hifigan.py
```
The above command generates 3 files:
- hifigan_v1.onnx
- hifigan_v2.onnx
- hifigan_v3.onnx
**HINT**: You can download pre-exported hifigan ONNX models from
<https://github.com/k2-fsa/sherpa-onnx/releases/tag/vocoder-models>
To use the generated onnx files to generate speech from text, please run:
```bash
# First, generate ./lexicon.txt
python3 ./matcha/generate_lexicon.py
python3 ./matcha/onnx_pretrained.py \
--acoustic-model ./model-steps-4.onnx \
--vocoder ./hifigan_v2.onnx \
--tokens ./data/tokens.txt \
--lexicon ./lexicon.txt \
--input-text "在一个阳光明媚的夏天,小马、小羊和小狗它们一块儿在广阔的草地上,嬉戏玩耍,这时小猴来了,还带着它心爱的足球活蹦乱跳地跑前、跑后教小马、小羊、小狗踢足球。" \
--output-wav ./1.wav
```
```bash
soxi ./1.wav
Input File : './1.wav'
Channels : 1
Sample Rate : 22050
Precision : 16-bit
Duration : 00:00:16.37 = 360960 samples ~ 1227.76 CDDA sectors
File Size : 722k
Bit Rate : 353k
Sample Encoding: 16-bit Signed Integer PCM
```
https://github.com/user-attachments/assets/578d04bb-fee8-47e5-9984-a868dcce610e

View File

@ -0,0 +1 @@
../matcha/audio.py

View File

@ -0,0 +1,110 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the baker-zh dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet
from icefall.utils import get_executor
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-jobs",
type=int,
default=4,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
return parser
def compute_fbank_baker_zh(num_jobs: int):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
if num_jobs < 1:
num_jobs = os.cpu_count()
logging.info(f"num_jobs: {num_jobs}")
logging.info(f"src_dir: {src_dir}")
logging.info(f"output_dir: {output_dir}")
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=22050,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
prefix = "baker_zh"
suffix = "jsonl.gz"
extractor = MatchaFbank(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts.{suffix}"
logging.info(f"Processing {cuts_filename}")
cut_set = load_manifest(src_dir / cuts_filename).resample(22050)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_parser().parse_args()
compute_fbank_baker_zh(args.num_jobs)

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/local/compute_fbank_statistics.py

View File

@ -0,0 +1,121 @@
#!/usr/bin/env python3
import argparse
import re
from typing import List
import jieba
from lhotse import load_manifest
from pypinyin import Style, lazy_pinyin, load_phrases_dict
load_phrases_dict(
{
"行长": [["hang2"], ["zhang3"]],
"银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
}
)
whiter_space_re = re.compile(r"\s+")
punctuations_re = [
(re.compile(x[0], re.IGNORECASE), x[1])
for x in [
("", ","),
("", "."),
("", "!"),
("", "?"),
("", '"'),
("", '"'),
("", "'"),
("", "'"),
("", ":"),
("", ","),
("", ""),
("", ""),
]
]
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--in-file",
type=str,
required=True,
help="Input cutset.",
)
parser.add_argument(
"--out-file",
type=str,
required=True,
help="Output cutset.",
)
return parser
def normalize_white_spaces(text):
return whiter_space_re.sub(" ", text)
def normalize_punctuations(text):
for regex, replacement in punctuations_re:
text = re.sub(regex, replacement, text)
return text
def split_text(text: str) -> List[str]:
"""
Example input: '你好呀You are 一个好人。 去银行存钱How about you?'
Example output: ['你好', '', ',', 'you are', '一个', '好人', '.', '', '银行', '存钱', '?', 'how about you', '?']
"""
text = text.lower()
text = normalize_white_spaces(text)
text = normalize_punctuations(text)
ans = []
for seg in jieba.cut(text):
if seg in ",.!?:\"'":
ans.append(seg)
elif seg == " " and len(ans) > 0:
if ord("a") <= ord(ans[-1][-1]) <= ord("z"):
ans[-1] += seg
elif ord("a") <= ord(seg[0]) <= ord("z"):
if len(ans) == 0:
ans.append(seg)
continue
if ans[-1][-1] == " ":
ans[-1] += seg
continue
ans.append(seg)
else:
ans.append(seg)
ans = [s.strip() for s in ans]
return ans
def main():
args = get_parser().parse_args()
cuts = load_manifest(args.in_file)
for c in cuts:
assert len(c.supervisions) == 1, (len(c.supervisions), c.supervisions)
text = c.supervisions[0].normalized_text
text_list = split_text(text)
tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
c.tokens = tokens
cuts.to_file(args.out_file)
print(f"saved to {args.out_file}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../matcha/fbank.py

View File

@ -0,0 +1,85 @@
#!/usr/bin/env python3
"""
This file generates the file tokens.txt.
Usage:
python3 ./local/generate_tokens.py --tokens data/tokens.txt
"""
import argparse
from typing import List
import jieba
from pypinyin import Style, lazy_pinyin, pinyin_dict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to to save tokens.txt.",
)
return parser
def generate_token_list() -> List[str]:
token_set = set()
word_dict = pinyin_dict.pinyin_dict
i = 0
for key in word_dict:
if not (0x4E00 <= key <= 0x9FFF):
continue
w = chr(key)
t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
token_set.add(t)
no_digit = set()
for t in token_set:
if t[-1] not in "1234":
no_digit.add(t)
else:
no_digit.add(t[:-1])
no_digit.add("dei")
no_digit.add("tou")
no_digit.add("dia")
for t in no_digit:
token_set.add(t)
for i in range(1, 5):
token_set.add(f"{t}{i}")
ans = list(token_set)
ans.sort()
punctuations = list(",.!?:\"'")
ans = punctuations + ans
# use ID 0 for blank
# Use ID 1 of _ for padding
ans.insert(0, " ")
ans.insert(1, "_") #
return ans
def main():
args = get_parser().parse_args()
token_list = generate_token_list()
with open(args.tokens, "w", encoding="utf-8") as f:
for indx, token in enumerate(token_list):
f.write(f"{token} {indx}\n")
if __name__ == "__main__":
main()
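
The resulting tokens.txt therefore starts with the space token (ID 0, used for blank), "_" (ID 1, used for padding), the punctuation marks, and then the sorted pinyin syllables. A quick way to peek at it (a sketch; the exact pinyin entries depend on pypinyin's dictionary):

```python
# Sketch: print the first few "<token> <id>" entries of the generated token table.
with open("data/tokens.txt", encoding="utf-8") as f:
    for _ in range(10):
        print(repr(f.readline().rstrip("\n")))
# Expected shape: "  0" (space), "_ 1", then ", 2", ". 3", "! 4", "? 5", ": 6", '" 7', "' 8", ...
```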

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/spectrogram/baker_zh_cuts_all.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet), type(cut_set)
validate_for_tts(cut_set)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/audio.py

View File

@ -0,0 +1,207 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script exports a Matcha-TTS model to ONNX.
Note that the model outputs fbank. You need to use a vocoder to convert
it to audio. See also ./export_onnx_hifigan.py
python3 ./matcha/export_onnx.py \
--exp-dir ./matcha/exp-1 \
--epoch 2000 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Any, Dict
import onnx
import torch
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=2000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp-new-3",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=Path,
default="data/tokens.txt",
)
parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
class ModelWrapper(torch.nn.Module):
def __init__(self, model, num_steps: int = 5):
super().__init__()
self.model = model
self.num_steps = num_steps
def forward(
self,
x: torch.Tensor,
x_lengths: torch.Tensor,
noise_scale: torch.Tensor,
length_scale: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x: (batch_size, num_tokens), torch.int64
x_lengths: (batch_size,), torch.int64
noise_scale: (1,), torch.float32
length_scale: (1,), torch.float32
Returns:
mel: (batch_size, feat_dim, num_frames)
"""
mel = self.model.synthesise(
x=x,
x_lengths=x_lengths,
n_timesteps=self.num_steps,
temperature=noise_scale,
length_scale=length_scale,
)["mel"]
# mel: (batch_size, feat_dim, num_frames)
return mel
@torch.inference_mode()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.pad_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
for num_steps in [2, 3, 4, 5, 6]:
logging.info(f"num_steps: {num_steps}")
wrapper = ModelWrapper(model, num_steps=num_steps)
wrapper.eval()
# Use a large value so the rotary position embedding in the text
# encoder has a large initial length
x = torch.ones(1, 1000, dtype=torch.int64)
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1.0])
length_scale = torch.tensor([1.0])
opset_version = 14
filename = f"model-steps-{num_steps}.onnx"
torch.onnx.export(
wrapper,
(x, x_lengths, noise_scale, length_scale),
filename,
opset_version=opset_version,
input_names=["x", "x_length", "noise_scale", "length_scale"],
output_names=["mel"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_length": {0: "N"},
"mel": {0: "N", 2: "L"},
},
)
meta_data = {
"model_type": "matcha-tts",
"language": "Chinese",
"has_espeak": 0,
"n_speakers": 1,
"jieba": 1,
"sample_rate": 22050,
"version": 1,
"pad_id": params.pad_id,
"model_author": "icefall",
"maintainer": "k2-fsa",
"dataset": "baker-zh",
"use_eos_bos": 0,
"dataset_url": "https://www.data-baker.com/open_source.html",
"dataset_comment": "The dataset is for non-commercial use only.",
"num_ode_steps": num_steps,
}
add_meta_data(filename=filename, meta_data=meta_data)
print(meta_data)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/export_onnx_hifigan.py

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/fbank.py

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python3
import jieba
from pypinyin import Style, lazy_pinyin, load_phrases_dict, phrases_dict, pinyin_dict
from tokenizer import Tokenizer
load_phrases_dict(
{
"行长": [["hang2"], ["zhang3"]],
"银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
}
)
def main():
filename = "lexicon.txt"
tokens = "./data/tokens.txt"
tokenizer = Tokenizer(tokens)
word_dict = pinyin_dict.pinyin_dict
phrases = phrases_dict.phrases_dict
i = 0
with open(filename, "w", encoding="utf-8") as f:
for key in word_dict:
if not (0x4E00 <= key <= 0x9FFF):
continue
w = chr(key)
tokens = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
f.write(f"{w} {tokens}\n")
for key in phrases:
tokens = lazy_pinyin(key, style=Style.TONE3, tone_sandhi=True)
tokens = " ".join(tokens)
f.write(f"{key} {tokens}\n")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/hifigan

342
egs/baker_zh/TTS/matcha/infer.py Executable file
View File

@ -0,0 +1,342 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""
python3 ./matcha/infer.py \
--epoch 2000 \
--exp-dir ./matcha/exp-1 \
--vocoder ./generator_v2 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav ./generated.wav
"""
import argparse
import datetime as dt
import json
import logging
from pathlib import Path
import soundfile as sf
import torch
import torch.nn as nn
from hifigan.config import v1, v2, v3
from hifigan.denoiser import Denoiser
from hifigan.models import Generator as HiFiGAN
from local.convert_text_to_tokens import split_text
from pypinyin import Style, lazy_pinyin
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import BakerZhTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=4000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--vocoder",
type=Path,
default="./generator_v1",
help="Path to the vocoder",
)
parser.add_argument(
"--tokens",
type=Path,
default="data/tokens.txt",
)
parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)
# The following arguments are used for inference on single text
parser.add_argument(
"--input-text",
type=str,
required=False,
help="The text to generate speech for",
)
parser.add_argument(
"--output-wav",
type=str,
required=False,
help="The filename of the wave to save the generated speech",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=22050,
help="The sampling rate of the generated speech (default: 22050 for baker_zh)",
)
return parser
def load_vocoder(checkpoint_path: Path) -> nn.Module:
checkpoint_path = str(checkpoint_path)
if checkpoint_path.endswith("v1"):
h = AttributeDict(v1)
elif checkpoint_path.endswith("v2"):
h = AttributeDict(v2)
elif checkpoint_path.endswith("v3"):
h = AttributeDict(v3)
else:
raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
hifigan = HiFiGAN(h).to("cpu")
hifigan.load_state_dict(
torch.load(checkpoint_path, map_location="cpu")["generator"]
)
_ = hifigan.eval()
hifigan.remove_weight_norm()
return hifigan
def to_waveform(
mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
) -> torch.Tensor:
audio = vocoder(mel).clamp(-1, 1)
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
return audio.squeeze()
def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
text = split_text(text)
tokens = lazy_pinyin(text, style=Style.TONE3, tone_sandhi=True)
x = tokenizer.texts_to_token_ids([tokens])
x = torch.tensor(x, dtype=torch.long, device=device)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
def synthesize(
model: nn.Module,
tokenizer: Tokenizer,
n_timesteps: int,
text: str,
length_scale: float,
temperature: float,
device: str = "cpu",
spks=None,
) -> dict:
text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
start_t = dt.datetime.now()
output = model.synthesise(
text_processed["x"],
text_processed["x_lengths"],
n_timesteps=n_timesteps,
temperature=temperature,
spks=spks,
length_scale=length_scale,
)
# merge everything to one dict
output.update({"start_t": start_t, **text_processed})
return output
def infer_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
vocoder: nn.Module,
denoiser: nn.Module,
tokenizer: Tokenizer,
) -> None:
"""Decode dataset.
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
tokenizer:
Used to convert text to phonemes.
"""
device = next(model.parameters()).device
num_cuts = 0
log_interval = 5
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["tokens"])
texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
for i in range(batch_size):
output = synthesize(
model=model,
tokenizer=tokenizer,
n_timesteps=params.n_timesteps,
text=texts[i],
length_scale=params.length_scale,
temperature=params.temperature,
device=device,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
sf.write(
file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
data=output["waveform"],
samplerate=params.data_args.sampling_rate,
subtype="PCM_16",
)
sf.write(
file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
data=audio[i].numpy(),
samplerate=params.data_args.sampling_rate,
subtype="PCM_16",
)
num_cuts += batch_size
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
@torch.inference_mode()
def main():
parser = get_parser()
BakerZhTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}"
params.res_dir = params.exp_dir / "infer" / params.suffix
params.save_wav_dir = params.res_dir / "wav"
params.save_wav_dir.mkdir(parents=True, exist_ok=True)
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
logging.info("Infer started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
# Number of ODE Solver steps
params.n_timesteps = 2
# Changes to the speaking rate
params.length_scale = 1.0
# Sampling temperature
params.temperature = 0.667
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
# we need cut ids to organize tts results.
args.return_cuts = True
baker_zh = BakerZhTtsDataModule(args)
test_cuts = baker_zh.test_cuts()
test_dl = baker_zh.test_dataloaders(test_cuts)
if not Path(params.vocoder).is_file():
raise ValueError(f"{params.vocoder} does not exist")
vocoder = load_vocoder(params.vocoder)
vocoder.to(device)
denoiser = Denoiser(vocoder, mode="zeros")
denoiser.to(device)
if params.input_text is not None and params.output_wav is not None:
logging.info("Synthesizing a single text")
output = synthesize(
model=model,
tokenizer=tokenizer,
n_timesteps=params.n_timesteps,
text=params.input_text,
length_scale=params.length_scale,
temperature=params.temperature,
device=device,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
sf.write(
file=params.output_wav,
data=output["waveform"],
samplerate=params.sampling_rate,
subtype="PCM_16",
)
else:
logging.info("Decoding the test set")
infer_dataset(
dl=test_dl,
params=params,
model=model,
vocoder=vocoder,
denoiser=denoiser,
tokenizer=tokenizer,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/model.py

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/models

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/monotonic_align

View File

@ -0,0 +1,316 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""
python3 ./matcha/onnx_pretrained.py \
--acoustic-model ./model-steps-4.onnx \
--vocoder ./hifigan_v2.onnx \
--tokens ./data/tokens.txt \
--lexicon ./lexicon.txt \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav ./b.wav
"""
import argparse
import datetime as dt
import logging
import re
from typing import Dict, List
import jieba
import onnxruntime as ort
import soundfile as sf
import torch
from infer import load_vocoder
from utils import intersperse
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--acoustic-model",
type=str,
required=True,
help="Path to the acoustic model",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to the tokens.txt",
)
parser.add_argument(
"--lexicon",
type=str,
required=True,
help="Path to the lexicon.txt",
)
parser.add_argument(
"--vocoder",
type=str,
required=True,
help="Path to the vocoder",
)
parser.add_argument(
"--input-text",
type=str,
required=True,
help="The text to generate speech for",
)
parser.add_argument(
"--output-wav",
type=str,
required=True,
help="The filename of the wave to save the generated speech",
)
return parser
class OnnxHifiGANModel:
def __init__(
self,
filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
for i in self.model.get_inputs():
print(i)
print("-----")
for i in self.model.get_outputs():
print(i)
def __call__(self, x: torch.tensor):
assert x.ndim == 3, x.shape
assert x.shape[0] == 1, x.shape
audio = self.model.run(
[self.model.get_outputs()[0].name],
{
self.model.get_inputs()[0].name: x.numpy(),
},
)[0]
# audio: (batch_size, num_samples)
return torch.from_numpy(audio)
class OnnxModel:
def __init__(
self,
filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 2
self.session_opts = session_opts
self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
metadata = self.model.get_modelmeta().custom_metadata_map
self.sample_rate = int(metadata["sample_rate"])
for i in self.model.get_inputs():
print(i)
print("-----")
for i in self.model.get_outputs():
print(i)
def __call__(self, x: torch.tensor):
assert x.ndim == 2, x.shape
assert x.shape[0] == 1, x.shape
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
print("x_lengths", x_lengths)
print("x", x.shape)
noise_scale = torch.tensor([1.0], dtype=torch.float32)
length_scale = torch.tensor([1.0], dtype=torch.float32)
mel = self.model.run(
[self.model.get_outputs()[0].name],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lengths.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: length_scale.numpy(),
},
)[0]
# mel: (batch_size, feat_dim, num_frames)
return torch.from_numpy(mel)
def read_tokens(filename: str) -> Dict[str, int]:
token2id = dict()
with open(filename, encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
if len(info) == 1:
# case of space
token = " "
idx = int(info[0])
else:
token, idx = info[0], int(info[1])
assert token not in token2id, token
token2id[token] = idx
return token2id
def read_lexicon(filename: str) -> Dict[str, List[str]]:
word2token = dict()
with open(filename, encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
w = info[0]
tokens = info[1:]
word2token[w] = tokens
return word2token
def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]:
if word in word2tokens:
return word2tokens[word]
if len(word) == 1:
return []
ans = []
for w in word:
t = convert_word_to_tokens(word2tokens, w)
ans.extend(t)
return ans
def normalize_text(text):
whiter_space_re = re.compile(r"\s+")
punctuations_re = [
(re.compile(x[0], re.IGNORECASE), x[1])
for x in [
("", ","),
("", "."),
("", "!"),
("", "?"),
("", '"'),
("", '"'),
("", "'"),
("", "'"),
("", ":"),
("", ","),
]
]
for regex, replacement in punctuations_re:
text = re.sub(regex, replacement, text)
return text
@torch.no_grad()
def main():
params = get_parser().parse_args()
logging.info(vars(params))
token2id = read_tokens(params.tokens)
word2tokens = read_lexicon(params.lexicon)
text = normalize_text(params.input_text)
seg = jieba.cut(text)
tokens = []
for s in seg:
if s in token2id:
tokens.append(s)
continue
t = convert_word_to_tokens(word2tokens, s)
if t:
tokens.extend(t)
model = OnnxModel(params.acoustic_model)
vocoder = OnnxHifiGANModel(params.vocoder)
x = []
for t in tokens:
if t in token2id:
x.append(token2id[t])
x = intersperse(x, item=token2id["_"])
x = torch.tensor(x, dtype=torch.int64).unsqueeze(0)
start_t = dt.datetime.now()
mel = model(x)
end_t = dt.datetime.now()
start_t2 = dt.datetime.now()
audio = vocoder(mel)
end_t2 = dt.datetime.now()
print("audio", audio.shape) # (1, 1, num_samples)
audio = audio.squeeze()
sample_rate = model.sample_rate
t = (end_t - start_t).total_seconds()
t2 = (end_t2 - start_t2).total_seconds()
rtf_am = t * sample_rate / audio.shape[-1]
rtf_vocoder = t2 * sample_rate / audio.shape[-1]
print("RTF for acoustic model ", rtf_am)
print("RTF for vocoder", rtf_vocoder)
# skip denoiser
sf.write(params.output_wav, audio, sample_rate, "PCM_16")
logging.info(f"Saved to {params.output_wav}")
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
"""
|HifiGAN |RTF |#Parameters (M)|
|----------|-----|---------------|
|v1 |0.818| 13.926 |
|v2 |0.101| 0.925 |
|v3 |0.118| 1.462 |
|Num steps|Acoustic Model RTF|
|---------|------------------|
| 2 | 0.039 |
| 3 | 0.047 |
| 4 | 0.071 |
| 5 | 0.076 |
| 6 | 0.103 |
"""

View File

@ -0,0 +1,119 @@
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import logging
from typing import Dict, List
import tacotron_cleaner.cleaners
try:
from piper_phonemize import phonemize_espeak
except Exception as ex:
raise RuntimeError(
f"{ex}\nPlease run\n"
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
)
from utils import intersperse
# This tokenizer supports both English and Chinese.
# We assume you have used
# ../local/convert_text_to_tokens.py
# to process your text
class Tokenizer(object):
def __init__(self, tokens: str):
"""
Args:
tokens: the file that maps tokens to ids
"""
# Parse token file
self.token2id: Dict[str, int] = {}
with open(tokens, "r", encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
if len(info) == 1:
# case of space
token = " "
id = int(info[0])
else:
token, id = info[0], int(info[1])
assert token not in self.token2id, token
self.token2id[token] = id
# Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
self.pad_id = self.token2id["_"] # padding
self.space_id = self.token2id[" "] # word separator (whitespace)
self.vocab_size = len(self.token2id)
def texts_to_token_ids(
self,
sentence_list: List[List[str]],
intersperse_blank: bool = True,
lang: str = "en-us",
) -> List[List[int]]:
"""
Args:
sentence_list:
A list of sentences.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
lang:
Language argument passed to phonemize_espeak().
Returns:
Return a list of token id list [utterance][token_id]
"""
token_ids_list = []
for sentence in sentence_list:
tokens_list = []
for word in sentence:
if word in self.token2id:
tokens_list.append(word)
continue
tmp_tokens_list = phonemize_espeak(word, lang)
for t in tmp_tokens_list:
tokens_list.extend(t)
token_ids = []
for t in tokens_list:
if t not in self.token2id:
logging.warning(f"Skip OOV {t} {sentence}")
continue
if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id:
continue
token_ids.append(self.token2id[t])
if intersperse_blank:
token_ids = intersperse(token_ids, self.pad_id)
token_ids_list.append(token_ids)
return token_ids_list
def test_tokenizer():
import jieba
from pypinyin import Style, lazy_pinyin
tokenizer = Tokenizer("data/tokens.txt")
text1 = "今天is Monday, tomorrow is 星期二"
text2 = "你好吗? 我很好, how about you?"
text1 = list(jieba.cut(text1))
text2 = list(jieba.cut(text2))
tokens1 = lazy_pinyin(text1, style=Style.TONE3, tone_sandhi=True)
tokens2 = lazy_pinyin(text2, style=Style.TONE3, tone_sandhi=True)
print(tokens1)
print(tokens2)
ids = tokenizer.texts_to_token_ids([tokens1, tokens2])
print(ids)
if __name__ == "__main__":
test_tokenizer()

717
egs/baker_zh/TTS/matcha/train.py Executable file
View File

@ -0,0 +1,717 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import json
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.utils import fix_random_seed
from model import fix_len_compatibility
from models.matcha_tts import MatchaTTS
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import BakerZhTtsDataModule
from utils import MetricsTracker
from icefall.checkpoint import load_checkpoint, save_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12335,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=1000,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--save-every-n",
type=int,
default=10,
help="""Save checkpoint after processing this number of epochs"
periodically. We save checkpoint to exp-dir/ whenever
params.cur_epoch % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
Since it will take around 1000 epochs, we suggest using a large
save_every_n to save disk space.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
def get_data_statistics():
return AttributeDict(
{
"mel_mean": 0,
"mel_std": 1,
}
)
def _get_data_params() -> AttributeDict:
params = AttributeDict(
{
"name": "baker-zh",
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
# "batch_size": 64,
# "num_workers": 1,
# "pin_memory": False,
"cleaners": ["english_cleaners2"],
"add_blank": True,
"n_spks": 1,
"n_fft": 1024,
"n_feats": 80,
"sampling_rate": 22050,
"hop_length": 256,
"win_length": 1024,
"f_min": 0,
"f_max": 8000,
"seed": 1234,
"load_durations": False,
"data_statistics": get_data_statistics(),
}
)
return params
def _get_model_params() -> AttributeDict:
n_feats = 80
filter_channels_dp = 256
encoder_params_p_dropout = 0.1
params = AttributeDict(
{
"n_spks": 1, # for baker-zh.
"spk_emb_dim": 64,
"n_feats": n_feats,
"out_size": None, # or use 172
"prior_loss": True,
"use_precomputed_durations": False,
"data_statistics": get_data_statistics(),
"encoder": AttributeDict(
{
"encoder_type": "RoPE Encoder", # not used
"encoder_params": AttributeDict(
{
"n_feats": n_feats,
"n_channels": 192,
"filter_channels": 768,
"filter_channels_dp": filter_channels_dp,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": encoder_params_p_dropout,
"spk_emb_dim": 64,
"n_spks": 1,
"prenet": True,
}
),
"duration_predictor_params": AttributeDict(
{
"filter_channels_dp": filter_channels_dp,
"kernel_size": 3,
"p_dropout": encoder_params_p_dropout,
}
),
}
),
"decoder": AttributeDict(
{
"channels": [256, 256],
"dropout": 0.05,
"attention_head_dim": 64,
"n_blocks": 1,
"num_mid_blocks": 2,
"num_heads": 2,
"act_fn": "snakebeta",
}
),
"cfm": AttributeDict(
{
"name": "CFM",
"solver": "euler",
"sigma_min": 1e-4,
}
),
"optimizer": AttributeDict(
{
"lr": 1e-4,
"weight_decay": 0.0,
}
),
}
)
return params
def get_params():
params = AttributeDict(
{
"model_args": _get_model_params(),
"data_args": _get_data_params(),
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": -1, # 0
"log_interval": 10,
"valid_interval": 1500,
"env_info": get_env_info(),
}
)
return params
def get_model(params):
m = MatchaTTS(**params.model_args)
return m
def load_checkpoint_if_available(
params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
Returns:
Return a dict containing previously saved training info.
"""
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(filename, model=model)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
"""Parse batch data"""
mel_mean = params.data_args.data_statistics.mel_mean
mel_std_inv = 1 / params.data_args.data_statistics.mel_std
for i in range(batch["features"].shape[0]):
n = batch["features_lens"][i]
batch["features"][i : i + 1, :n, :] = (
batch["features"][i : i + 1, :n, :] - mel_mean
) * mel_std_inv
batch["features"][i : i + 1, n:, :] = 0
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
max_feature_length = fix_len_compatibility(features.shape[1])
if max_feature_length > features.shape[1]:
pad = max_feature_length - features.shape[1]
features = torch.nn.functional.pad(features, (0, 0, 0, pad))
# features_lens[features_lens.argmax()] += pad
return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
# used to summary the stats over iterations
tot_loss = MetricsTracker()
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device, params)
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
batch_size = len(batch["tokens"])
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
s = 0
for key, value in losses.items():
v = value.detach().item()
loss_info[key] = v * batch_size
s += v * batch_size
loss_info["tot_loss"] = s
# summary stats
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
optimizer: Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
tb_writer:
Writer to write log messages to tensorboard.
"""
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
params=params,
optimizer=optimizer,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
# audio: (N, T), float32
# features: (N, T, C), float32
# audio_lens, (N,), int32
# features_lens, (N,), int32
# tokens: List[List[str]], len(tokens) == N
batch_size = len(batch["tokens"])
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device, params)
try:
with autocast(enabled=params.use_fp16):
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
loss = sum(losses.values())
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
s = 0
for key, value in losses.items():
v = value.detach().item()
loss_info[key] = v * batch_size
s += v * batch_size
loss_info["tot_loss"] = s
tot_loss = tot_loss + loss_info
except: # noqa
save_bad_model()
raise
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it.
# The _growth_interval of the grad scaler is configurable,
# but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if params.batch_idx_train % params.log_interval == 0:
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, "
f"batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if params.batch_idx_train % params.valid_interval == 1:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
world_size=world_size,
rank=rank,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
"Maximum memory allocated so far is "
f"{torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.pad_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
logging.info(params)
print(params)
logging.info("About to create model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of parameters: {num_param}")
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
logging.info("About to create datamodule")
baker_zh = BakerZhTtsDataModule(args)
train_cuts = baker_zh.train_cuts()
train_dl = baker_zh.train_dataloaders(train_cuts)
valid_cuts = baker_zh.valid_cuts()
valid_dl = baker_zh.valid_dataloaders(valid_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
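# init_scale=1.0 starts fp16 loss scaling conservatively; the fp16 branch in
# train_one_epoch periodically doubles the scale while it stays small.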
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
fix_random_seed(params.seed + epoch - 1)
if "sampler" in train_dl:
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
train_one_epoch(
params=params,
model=model,
tokenizer=tokenizer,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer=optimizer,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
BakerZhTtsDataModule.add_arguments(parser)
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

View File

@ -0,0 +1,340 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
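# Gives each dataloader worker process its own random seed so that any
# randomness in sampling or on-the-fly feature extraction differs per worker.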
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class BakerZhTtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in TTS
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in TTS tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
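# When --on-the-fly-feats is set, the dataset is rebuilt below with an fbank
# extractor whose settings (22.05 kHz, 80 mels, hop 256) are assumed to match
# the precomputed-feature configuration, so both paths yield compatible inputs.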
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=True,
pin_memory=True,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/utils.py

151
egs/baker_zh/TTS/prepare.sh Executable file
View File

@ -0,0 +1,151 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=-1
stop_stage=100
dl_dir=$PWD/download
mkdir -p $dl_dir
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: build monotonic_align lib (used by ./matcha)"
for recipe in matcha; do
if [ ! -d $recipe/monotonic_align/build ]; then
cd $recipe/monotonic_align
python3 setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib for $recipe already built"
fi
done
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# The directory $dl_dir/BZNSYP contains the following 3 directories
# ls -lh $dl_dir/BZNSYP/
# total 0
# drwxr-xr-x 10002 kuangfangjun root 0 Jan 4 2019 PhoneLabeling
# drwxr-xr-x 3 kuangfangjun root 0 Jan 31 2019 ProsodyLabeling
# drwxr-xr-x 10003 kuangfangjun root 0 Aug 26 17:45 Wave
# If you have trouble accessing huggingface.co, please use
#
# cd $dl_dir
# wget https://huggingface.co/openspeech/BZNSYP/resolve/main/BZNSYP.tar.bz2
# tar xf BZNSYP.tar.bz2
# cd ..
# If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink
#
# ln -sfv /path/to/BZNSYP $dl_dir/BZNSYP
#
if [ ! -d $dl_dir/BZNSYP/Wave ]; then
lhotse download baker-zh $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare baker-zh manifest"
# We assume that you have downloaded the baker corpus
# to $dl_dir/BZNSYP
mkdir -p data/manifests
if [ ! -e data/manifests/.baker-zh.done ]; then
lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests
touch data/manifests/.baker-zh.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Generate tokens.txt"
if [ ! -e data/tokens.txt ]; then
python3 ./local/generate_tokens.py --tokens data/tokens.txt
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Generate raw cutset"
if [ ! -e data/manifests/baker_zh_cuts_raw.jsonl.gz ]; then
lhotse cut simple \
-r ./data/manifests/baker_zh_recordings_all.jsonl.gz \
-s ./data/manifests/baker_zh_supervisions_all.jsonl.gz \
./data/manifests/baker_zh_cuts_raw.jsonl.gz
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Convert text to tokens"
if [ ! -e data/manifests/baker_zh_cuts.jsonl.gz ]; then
python3 ./local/convert_text_to_tokens.py \
--in-file ./data/manifests/baker_zh_cuts_raw.jsonl.gz \
--out-file ./data/manifests/baker_zh_cuts.jsonl.gz
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate fbank (used by ./matcha)"
mkdir -p data/fbank
if [ ! -e data/fbank/.baker-zh.done ]; then
./local/compute_fbank_baker_zh.py
touch data/fbank/.baker-zh.done
fi
if [ ! -e data/fbank/.baker-zh-validated.done ]; then
log "Validating data/fbank for baker-zh (used by ./matcha)"
python3 ./local/validate_manifest.py \
data/fbank/baker_zh_cuts.jsonl.gz
touch data/fbank/.baker-zh-validated.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Split the baker-zh cuts into train, valid and test sets (used by ./matcha)"
if [ ! -e data/fbank/.baker_zh_split.done ]; then
lhotse subset --last 600 \
data/fbank/baker_zh_cuts.jsonl.gz \
data/fbank/baker_zh_cuts_validtest.jsonl.gz
lhotse subset --first 100 \
data/fbank/baker_zh_cuts_validtest.jsonl.gz \
data/fbank/baker_zh_cuts_valid.jsonl.gz
lhotse subset --last 500 \
data/fbank/baker_zh_cuts_validtest.jsonl.gz \
data/fbank/baker_zh_cuts_test.jsonl.gz
rm data/fbank/baker_zh_cuts_validtest.jsonl.gz
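# At this point the held-out 600 cuts have been split into 100 validation cuts
# and 500 test cuts; the remaining cuts become the training set below.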
n=$(( $(gunzip -c data/fbank/baker_zh_cuts.jsonl.gz | wc -l) - 600 ))
lhotse subset --first $n \
data/fbank/baker_zh_cuts.jsonl.gz \
data/fbank/baker_zh_cuts_train.jsonl.gz
touch data/fbank/.baker_zh_split.done
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 6: Compute fbank mean and std (used by ./matcha)"
if [ ! -f ./data/fbank/cmvn.json ]; then
./local/compute_fbank_statistics.py ./data/fbank/baker_zh_cuts_train.jsonl.gz ./data/fbank/cmvn.json
fi
fi

1
egs/baker_zh/TTS/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -339,7 +339,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
# 2. chmod +x ./jq
# 3. cp jq /usr/bin
gunzip -c ${file} \
-  | jq '.text' | sed 's/"//g' > $lang_dir/transcript_words.txt
+  | jq '.supervisions[].text' | sed 's/"//g' > $lang_dir/transcript_words.txt
# Ensure space only appears once
sed -i 's/\t/ /g' $lang_dir/transcript_words.txt

View File

@ -161,14 +161,14 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Split XL subset into pieces (may take 30 minutes)" log "Stage 5: Split XL subset into pieces (may take 30 minutes)"
split_dir=data/fbank/XL_split split_dir=data/fbank/XL_split
if [ ! -f $split_dir/.split_completed ]; then if [ ! -f $split_dir/.split_completed ]; then
lhotse split-lazy ./data/fbank/cuts_XL_raw.jsonl.gz $split_dir $num_per_split lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $num_per_split
touch $split_dir/.split_completed touch $split_dir/.split_completed
fi fi
fi fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Compute features for XL" log "Stage 6: Compute features for XL"
num_splits=$(find data/fbank/XL_split -name "cuts_XL_raw.*.jsonl.gz" | wc -l) num_splits=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL_raw.*.jsonl.gz" | wc -l)
python3 ./local/compute_fbank_gigaspeech_splits.py \ python3 ./local/compute_fbank_gigaspeech_splits.py \
--num-workers 20 \ --num-workers 20 \
--batch-duration 600 \ --batch-duration 600 \
@ -177,9 +177,9 @@ fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Combine features for XL (may take 3 hours)" log "Stage 7: Combine features for XL (may take 3 hours)"
if [ ! -f data/fbank/cuts_XL.jsonl.gz ]; then if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then
pieces=$(find data/fbank/XL_split -name "cuts_XL.*.jsonl.gz") pieces=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL.*.jsonl.gz")
lhotse combine $pieces data/fbank/cuts_XL.jsonl.gz lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz
fi fi
fi fi

View File

@ -260,7 +260,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of
the returned dict.
"""
-device = model.device
+device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3

View File

@ -1,158 +0,0 @@
#!/usr/bin/env python3
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory `src_dir` (default is data/manifests).
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
MonoCut,
WhisperFbank,
WhisperFbankConfig,
combine,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
def compute_fbank_musan(
src_dir: str = "data/manifests",
num_mel_bins: int = 80,
whisper_fbank: bool = False,
output_dir: str = "data/fbank",
):
src_dir = Path(src_dir)
output_dir = Path(output_dir)
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"music",
"speech",
"noise",
)
prefix = "musan"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return
logging.info("Extracting features for Musan")
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(is_cut_long)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/musan_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
)
musan_cuts.to_file(musan_cuts_path)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--src-dir",
type=str,
default="data/manifests",
help="Source manifests directory.",
)
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
parser.add_argument(
"--output-dir",
type=str,
default="data/fbank",
help="Output directory. Default: data/fbank.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_musan(
src_dir=args.src_dir,
num_mel_bins=args.num_mel_bins,
whisper_fbank=args.whisper_fbank,
output_dir=args.output_dir,
)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -1,157 +0,0 @@
#!/usr/bin/env python3
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script removes short and long utterances from a cutset.
Caution:
You may need to tune the thresholds for your own dataset.
Usage example:
python3 ./local/filter_cuts.py \
--bpe-model data/lang_bpe_5000/bpe.model \
--in-cuts data/fbank/speechtools_cuts_test.jsonl.gz \
--out-cuts data/fbank-filtered/speechtools_cuts_test.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=Path,
help="Path to the bpe.model",
)
parser.add_argument(
"--in-cuts",
type=Path,
help="Path to the input cutset",
)
parser.add_argument(
"--out-cuts",
type=Path,
help="Path to the output cutset",
)
return parser.parse_args()
def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
total = 0 # number of total utterances before removal
removed = 0 # number of removed utterances
def remove_short_and_long_utterances(c: Cut):
"""Return False to exclude the input cut"""
nonlocal removed, total
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ./display_manifest_statistics.py
#
# You should use ./display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
total += 1
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
removed += 1
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./pruned_transducer_stateless2/conformer.py, the
# conv module uses the following expression
# for subsampling
if c.num_frames is None:
num_frames = c.duration * 100 # approximate
else:
num_frames = c.num_frames
T = ((num_frames - 1) // 2 - 1) // 2
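# Worked example (using the ~100 frames/second approximation above): a 10 s
# utterance has about 1000 frames, so T = ((1000 - 1) // 2 - 1) // 2 = 249
# frames after subsampling.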
# Note: for ./lstm_transducer_stateless/lstm.py, the formula is
# T = ((num_frames - 3) // 2 - 1) // 2
# Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is
# T = ((num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
removed += 1
return False
return True
# We use to_eager() here so that we can print out the value of total
# and removed below.
ans = cut_set.filter(remove_short_and_long_utterances).to_eager()
ratio = removed / total * 100
logging.info(
f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed."
)
return ans
def main():
args = get_args()
logging.info(vars(args))
if args.out_cuts.is_file():
logging.info(f"{args.out_cuts} already exists - skipping")
return
assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist"
assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist"
sp = spm.SentencePieceProcessor()
sp.load(str(args.bpe_model))
cut_set = load_manifest_lazy(args.in_cuts)
assert isinstance(cut_set, CutSet)
cut_set = filter_cuts(cut_set, sp)
logging.info(f"Saving to {args.out_cuts}")
args.out_cuts.parent.mkdir(parents=True, exist_ok=True)
cut_set.to_file(args.out_cuts)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/filter_cuts.py

View File

@ -1,115 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
import argparse
import shutil
from pathlib import Path
from typing import Dict
import sentencepiece as spm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
The generated bpe.model is saved to this directory.
""",
)
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
)
return parser.parse_args()
def generate_tokens(lang_dir: Path):
"""
Generate the tokens.txt from a bpe model.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(lang_dir / "bpe.model"))
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f:
for sym, i in token2id.items():
f.write(f"{sym} {i}\n")
def main():
args = get_args()
vocab_size = args.vocab_size
lang_dir = Path(args.lang_dir)
model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
train_text = args.transcript
character_coverage = 1.0
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
# Note: unk_id is fixed to 2.
# If you change it, you should also change other
# places that are using it.
model_file = Path(model_prefix + ".model")
if not model_file.is_file():
spm.SentencePieceTrainer.train(
input=train_text,
vocab_size=vocab_size,
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
bos_id=-1,
eos_id=-1,
)
else:
print(f"{model_file} exists - skipping")
return
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
generate_tokens(lang_dir)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py

View File

@ -1,101 +0,0 @@
#!/usr/bin/env python3
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
- Supervision time bounds are within cut time bounds
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/fbank/speechtools_cuts_train.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.dataset.speech_recognition import validate_for_asr
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def validate_one_supervision_per_cut(c: Cut):
if len(c.supervisions) != 1:
raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
def validate_supervision_and_cut_time_bounds(c: Cut):
tol = 2e-3 # same tolerance as in 'validate_for_asr()'
s = c.supervisions[0]
# Supervision start time is relative to Cut ...
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
if s.start < -tol:
raise ValueError(
f"{c.id}: Supervision start time {s.start} must not be negative."
)
if s.start > tol:
raise ValueError(
f"{c.id}: Supervision start time {s.start} is not at the beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`."
)
if c.start + s.end > c.end + tol:
raise ValueError(
f"{c.id}: Supervision end time {c.start+s.end} is larger "
f"than cut end time {c.end}"
)
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet)
for c in cut_set:
validate_one_supervision_per_cut(c)
validate_supervision_and_cut_time_bounds(c)
# Validation from K2 training
# - checks supervision start is 0
# - checks supervision.duration is not longer than cut.duration
# - there is tolerance 2ms
validate_for_asr(cut_set)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/validate_manifest.py

View File

@ -1 +0,0 @@
This recipe implements Zipformer model.

View File

@ -0,0 +1,303 @@
#!/usr/bin/env python3
"""
Usage:
1. Download pre-trained models from
https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer
2.
./dprnn_zipformer/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--tokens /path/to/data/lang_bpe_500/tokens.txt \
/path/to/foo.wav
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_surt_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_surt_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
B, T, F = features.shape
processed = model.mask_encoder(features) # B,T,F*num_channels
masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
x_masked = [features * m for m in masks]
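# Each estimated mask is applied to the mixture features, yielding one
# feature stream per output channel of the SURT model; the streams are then
# stacked along the batch axis and decoded jointly below.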
# Recognition
# Concatenate the inputs along the batch axis
h = torch.cat(x_masked, dim=0)
h_lens = feature_lengths.repeat(params.num_channels)
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
if model.joint_encoder_layer is not None:
encoder_out = model.joint_encoder_layer(encoder_out)
def _group_channels(hyps: List[str]) -> List[List[str]]:
"""
Currently we have a batch of size M*B, where M is the number of
channels and B is the batch size. We need to group the hypotheses
into B groups, each of which contains M hypotheses.
Example:
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
"""
assert len(hyps) == B * params.num_channels
out_hyps = []
for i in range(B):
out_hyps.append(hyps[i::B])
return out_hyps
hyps = []
msg = f"Using {params.method}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
hyps.append(token_ids_to_words(hyp))
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -62,9 +62,7 @@ from asr_datamodule import LibriCssAsrDataModule
from decoder import Decoder
from dprnn import DPRNN
from einops.layers.torch import Rearrange
-from graph_pit.loss.optimized import optimized_graph_pit_mse_loss as gpit_mse
from joiner import Joiner
-from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import LOG_EPSILON, fix_random_seed
from model import SURT

View File

@ -20,6 +20,8 @@ import json
import sys
from pathlib import Path
+from icefall.utils import str2bool
def simple_cleanup(text: str) -> str:
    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
@ -29,17 +31,21 @@ def simple_cleanup(text: str) -> str:
# Assign text of the supervisions and remove unnecessary entries.
def main():
-    assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
+    assert (
+        len(sys.argv) == 4
+    ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS"
    fname = Path(sys.argv[1]).name
    oname = Path(sys.argv[2]) / fname
+    keep_custom_fields = str2bool(sys.argv[3])
    with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
        for line in fin:
            cut = json.loads(line)
            cut["supervisions"][0]["text"] = simple_cleanup(
                cut["supervisions"][0]["custom"]["texts"][0]
            )
-            del cut["supervisions"][0]["custom"]
-            del cut["custom"]
+            if not keep_custom_fields:
+                del cut["supervisions"][0]["custom"]
+                del cut["custom"]
            fout.write((json.dumps(cut) + "\n").encode())

View File

@ -29,6 +29,11 @@ export CUDA_VISIBLE_DEVICES=""
# - speech
dl_dir=$PWD/download
+# If you want to do PromptASR experiments, please set it to True
+# as this will keep the texts and pre_text information required for
+# the training of PromptASR.
+keep_custom_fields=False
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
@ -134,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
for subset in small medium large dev test_clean test_other; do
if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
log "Prepare manifest for subset : ${subset}"
-./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir
+./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir $keep_custom_fields
fi
done
fi

View File

@ -50,7 +50,7 @@ We place an additional Conv1d layer right after the input embedding layer.
| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head |
| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty |
| `zipformer-ctc` | Zipformer | Use auxiliary attention head |
-| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head | The latest recipe |
+| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head (the latest recipe) |
# MMI
@ -58,3 +58,9 @@ We place an additional Conv1d layer right after the input embedding layer.
|------------------------------|-----------|---------------------------------------------------|
| `conformer-mmi` | Conformer | |
| `zipformer-mmi` | Zipformer | CTC warmup + use HP as decoding graph for decoding |
# CR-CTC
| | Encoder | Comment |
|------------------------------|--------------------|------------------------------|
| `zipformer` | Upgraded Zipformer | Could also be an auxiliary loss to improve transducer or CTC/AED (the latest recipe) |

View File

@ -1,5 +1,318 @@
## Results
### zipformer (zipformer + pruned-transducer w/ CR-CTC)
See <https://github.com/k2-fsa/icefall/pull/1766> for more details.
[zipformer](./zipformer)
#### Non-streaming
##### large-scale model, number of model parameters: 148824074, i.e., 148.8 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-transducer-with-CR-CTC-20241019>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| greedy_search | 1.9 | 3.96 | --epoch 50 --avg 26 |
| modified_beam_search | 1.88 | 3.95 | --epoch 50 --avg 26 |
The training command using 2 80G-A100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# for non-streaming model training:
./zipformer/train.py \
--world-size 2 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large-cr-ctc-rnnt \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 1 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--ctc-loss-scale 0.1 \
--enable-spec-aug 0 \
--cr-loss-scale 0.02 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 1400 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search; do
./zipformer/decode.py \
--epoch 50 \
--avg 26 \
--exp-dir zipformer/exp-large-cr-ctc-rnnt \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 1 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 300 \
--decoding-method $m
done
```
### zipformer (zipformer + CR-CTC-AED)
See <https://github.com/k2-fsa/icefall/pull/1766> for more details.
[zipformer](./zipformer)
#### Non-streaming
##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-cr-ctc-aed-20241020>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| attention-decoder-rescoring-no-ngram | 1.96 | 4.08 | --epoch 50 --avg 20 |
The training command using 2 80G-A100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# for non-streaming model training:
./zipformer/train.py \
--world-size 2 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large-cr-ctc-aed \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--ctc-loss-scale 0.1 \
--attention-decoder-loss-scale 0.9 \
--enable-spec-aug 0 \
--cr-loss-scale 0.02 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 1200 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 20 \
--exp-dir zipformer/exp-large-cr-ctc-aed/ \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 200 \
--decoding-method attention-decoder-rescoring-no-ngram
```
### zipformer (zipformer + CR-CTC)
See <https://github.com/k2-fsa/icefall/pull/1766> for more details.
[zipformer](./zipformer)
#### Non-streaming
##### small-scale model, number of model parameters: 22118279, i.e., 22.1 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-small-cr-ctc-20241018>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-greedy-decoding | 2.57 | 5.95 | --epoch 50 --avg 25 |
| ctc-prefix-beam-search | 2.52 | 5.85 | --epoch 50 --avg 25 |
The training command using 2 32G-V100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# for non-streaming model training:
./zipformer/train.py \
--world-size 2 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-small/ \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--base-lr 0.04 \
--enable-spec-aug 0 \
--cr-loss-scale 0.2 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 850 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search ctc-prefix-beam-search; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 25 \
--exp-dir zipformer/exp-small \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--max-duration 600 \
--decoding-method $m
done
```
##### medium-scale model, number of model parameters: 64250603, i.e., 64.3 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-medium-cr-ctc-20241018>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-greedy-decoding | 2.12 | 4.62 | --epoch 50 --avg 24 |
| ctc-prefix-beam-search | 2.1 | 4.61 | --epoch 50 --avg 24 |
The training command using 4 32G-V100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--enable-spec-aug 0 \
--cr-loss-scale 0.2 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 700 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search ctc-prefix-beam-search; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 24 \
--exp-dir zipformer/exp \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--max-duration 600 \
--decoding-method $m
done
```
##### large-scale model, number of model parameters: 147010094, i.e., 147.0 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-cr-ctc-20241018>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-greedy-decoding | 2.03 | 4.37 | --epoch 50 --avg 26 |
| ctc-prefix-beam-search | 2.02 | 4.35 | --epoch 50 --avg 26 |
The training command using 2 80G-A100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# For non-streaming model training:
./zipformer/train.py \
--world-size 2 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--enable-spec-aug 0 \
--cr-loss-scale 0.2 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 1400 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search ctc-prefix-beam-search; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 26 \
--exp-dir zipformer/exp-large \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 600 \
--decoding-method $m
done
```
### zipformer (zipformer + CTC/AED)
See <https://github.com/k2-fsa/icefall/pull/1389> for more details.
@ -307,6 +620,23 @@ done
To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).
We also support training Zipformer with AMP + bf16 (this requires hardware with bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same commands can be used for decoding and exporting the model.**
The AMP + bf16 training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 0 \
--use-bf16 1 \
--exp-dir zipformer/exp_amp_bf16 \
--causal 0 \
--full-libri 1 \
--max-duration 1000
```
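Before launching an AMP + bf16 run, it is worth checking that the GPU and the installed PyTorch build actually support bf16; a minimal check (assuming a CUDA build of PyTorch) is:
```bash
# Prints True if the current CUDA device can run bf16 kernels.
python3 -c "import torch; print(torch.cuda.is_bf16_supported())"
```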
##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M
The tensorboard log can be found at
@@ -32,7 +32,7 @@ class Conformer(Transformer):
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
- dim_feedforward (int): feedforward dimention
+ dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
@@ -902,7 +902,7 @@ class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
- """Return Swich activation function."""
+ """Return Swish activation function."""
return x * torch.sigmoid(x)
@@ -42,7 +42,7 @@ class Conformer(Transformer):
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension, also the output dimension
nhead (int): number of head
- dim_feedforward (int): feedforward dimention
+ dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
@@ -33,7 +33,7 @@ class Conformer(Transformer):
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
- dim_feedforward (int): feedforward dimention
+ dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
@@ -42,7 +42,7 @@ class Conformer(EncoderInterface):
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension, also the output dimension
nhead (int): number of head
- dim_feedforward (int): feedforward dimention
+ dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
layer_dropout (float): layer-dropout rate.
@@ -42,7 +42,7 @@ class Conformer(EncoderInterface):
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension, also the output dimension
nhead (int): number of head
- dim_feedforward (int): feedforward dimention
+ dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
layer_dropout (float): layer-dropout rate.
@@ -42,7 +42,7 @@ class Conformer(EncoderInterface):
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension, also the output dimension
nhead (int): number of head
- dim_feedforward (int): feedforward dimention
+ dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
layer_dropout (float): layer-dropout rate.
@@ -42,7 +42,7 @@ class Conformer(EncoderInterface):
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension, also the output dimension
nhead (int): number of head
- dim_feedforward (int): feedforward dimention
+ dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
layer_dropout (float): layer-dropout rate.
@@ -69,7 +69,7 @@ class Conformer(Transformer):
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
- dim_feedforward (int): feedforward dimention
+ dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
@@ -35,7 +35,7 @@ class Conformer(Transformer):
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
- dim_feedforward (int): feedforward dimention
+ dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
@@ -236,7 +236,7 @@ class TransformerDecoder(nn.Module):
causal_mask = subsequent_mask(x.shape[0], device=x.device)  # (seq_len, seq_len)
attn_mask = torch.logical_or(
padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
- torch.logical_not(causal_mask).unsqueeze(0)  # (1, seq_len, seq_len)
+ torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
)  # (batch, seq_len, seq_len)
if memory is not None:
@@ -367,7 +367,9 @@ class MultiHeadAttention(nn.Module):
self.num_heads = num_heads
self.head_dim = attention_dim // num_heads
assert self.head_dim * num_heads == attention_dim, (
- self.head_dim, num_heads, attention_dim
+ self.head_dim,
+ num_heads,
+ attention_dim,
)
self.dropout = dropout
self.name = None  # will be overwritten in training code; for diagnostics.
@@ -437,15 +439,19 @@ class MultiHeadAttention(nn.Module):
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
attn_weights = attn_weights.masked_fill(
- key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float("-inf"),
)
if attn_mask is not None:
- assert (
- attn_mask.shape == (batch, 1, src_len)
- or attn_mask.shape == (batch, tgt_len, src_len)
+ assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == (
+ batch,
+ tgt_len,
+ src_len,
), attn_mask.shape
- attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
+ attn_weights = attn_weights.masked_fill(
+ attn_mask.unsqueeze(1), float("-inf")
+ )
attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -456,7 +462,11 @@ class MultiHeadAttention(nn.Module):
# (batch * head, tgt_len, head_dim)
attn_output = torch.bmm(attn_weights, v)
- assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
+ assert attn_output.shape == (
+ batch * num_heads,
+ tgt_len,
+ head_dim,
+ ), attn_output.shape
attn_output = attn_output.transpose(0, 1).contiguous()
attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
@ -111,6 +111,7 @@ Usage:
import argparse import argparse
import logging import logging
import math import math
import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -120,6 +121,7 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import ( from icefall.checkpoint import (
@ -128,8 +130,12 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.context_graph import ContextGraph, ContextState
from icefall.decode import ( from icefall.decode import (
ctc_greedy_search, ctc_greedy_search,
ctc_prefix_beam_search,
ctc_prefix_beam_search_attention_decoder_rescoring,
ctc_prefix_beam_search_shallow_fussion,
get_lattice, get_lattice,
nbest_decoding, nbest_decoding,
nbest_oracle, nbest_oracle,
@ -140,6 +146,8 @@ from icefall.decode import (
rescore_with_whole_lattice, rescore_with_whole_lattice,
) )
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.lm_wrapper import LmScorer
from icefall.ngram_lm import NgramLm, NgramLmStateCost
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
get_texts, get_texts,
@ -254,6 +262,12 @@ def get_parser():
lattice, rescore them with the attention decoder. lattice, rescore them with the attention decoder.
- (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
rescored lattice, rescore them with the attention decoder. rescored lattice, rescore them with the attention decoder.
- (10) ctc-prefix-beam-search. Extract n paths with the given beam, the best
path of the n paths is the decoding result.
- (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with
the given beam, rescore them with the attention decoder.
- (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fussion during
beam search, LODR and hotwords are also supported in this decoding method.
""", """,
) )
@ -279,6 +293,23 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--nnlm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--nnlm-scale",
type=float,
default=0,
help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion.
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument( parser.add_argument(
"--hlg-scale", "--hlg-scale",
type=float, type=float,
@ -296,6 +327,54 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="ID of the backoff symbol in the ngram LM",
)
parser.add_argument(
"--lodr-ngram",
type=str,
help="The path to the lodr ngram",
)
parser.add_argument(
"--lodr-lm-scale",
type=float,
default=0,
help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.",
)
parser.add_argument(
"--context-score",
type=float,
default=0,
help="""
The bonus score of each token for the context biasing words/phrases.
0 means don't use contextual biasing.
Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path of the context biasing lists, one word/phrase each line
Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
""",
)
parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -306,11 +385,12 @@ def get_decoding_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"frame_shift_ms": 10, "frame_shift_ms": 10,
"search_beam": 20, "search_beam": 20, # for k2 fsa composition
"output_beam": 8, "output_beam": 8, # for k2 fsa composition
"min_active_states": 30, "min_active_states": 30,
"max_active_states": 10000, "max_active_states": 10000,
"use_double_scores": True, "use_double_scores": True,
"beam": 4, # for prefix-beam-search
} }
) )
return params return params
@ -325,6 +405,9 @@ def decode_one_batch(
batch: dict, batch: dict,
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
NNLM: Optional[LmScorer] = None,
LODR_lm: Optional[NgramLm] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -369,10 +452,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None. the returned dict. Note: If it decodes to nothing, then return None.
""" """
if HLG is not None: device = params.device
device = HLG.device
else:
device = H.device
feature = batch["inputs"] feature = batch["inputs"]
assert feature.ndim == 3 assert feature.ndim == 3
feature = feature.to(device) feature = feature.to(device)
@ -403,6 +483,51 @@ def decode_one_batch(
key = "ctc-greedy-search" key = "ctc-greedy-search"
return {key: hyps} return {key: hyps}
if params.decoding_method == "ctc-prefix-beam-search":
token_ids = ctc_prefix_beam_search(
ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "prefix-beam-search"
return {key: hyps}
if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
ctc_output=ctc_output,
attention_decoder=model.attention_decoder,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
ans = dict()
for a_scale_str, token_ids in best_path_dict.items():
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
ans[a_scale_str] = hyps
return ans
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
token_ids = ctc_prefix_beam_search_shallow_fussion(
ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens,
NNLM=NNLM,
LODR_lm=LODR_lm,
LODR_lm_scale=params.lodr_lm_scale,
context_graph=context_graph,
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "prefix-beam-search-shallow-fussion"
return {key: hyps}
supervision_segments = torch.stack( supervision_segments = torch.stack(
( (
supervisions["sequence_idx"], supervisions["sequence_idx"],
@ -455,7 +580,7 @@ def decode_one_batch(
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps] hyps = [s.split() for s in hyps]
key = "ctc-decoding" key = "ctc-decoding"
return {key: hyps} return {key: hyps} # note: returns words
if params.decoding_method == "attention-decoder-rescoring-no-ngram": if params.decoding_method == "attention-decoder-rescoring-no-ngram":
best_path_dict = rescore_with_attention_decoder_no_ngram( best_path_dict = rescore_with_attention_decoder_no_ngram(
@ -492,7 +617,7 @@ def decode_one_batch(
) )
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps] hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa
return {key: hyps} return {key: hyps}
if params.decoding_method in ["1best", "nbest"]: if params.decoding_method in ["1best", "nbest"]:
@ -500,7 +625,7 @@ def decode_one_batch(
best_path = one_best_decoding( best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores lattice=lattice, use_double_scores=params.use_double_scores
) )
key = "no_rescore" key = "no-rescore"
else: else:
best_path = nbest_decoding( best_path = nbest_decoding(
lattice=lattice, lattice=lattice,
@ -508,11 +633,11 @@ def decode_one_batch(
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
) )
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps] hyps = [[word_table[i] for i in ids] for ids in hyps]
return {key: hyps} return {key: hyps} # note: returns BPE tokens
assert params.decoding_method in [ assert params.decoding_method in [
"nbest-rescoring", "nbest-rescoring",
@ -576,6 +701,9 @@ def decode_dataset(
bpe_model: Optional[spm.SentencePieceProcessor], bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
NNLM: Optional[LmScorer] = None,
LODR_lm: Optional[NgramLm] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -626,6 +754,9 @@ def decode_dataset(
batch=batch, batch=batch,
word_table=word_table, word_table=word_table,
G=G, G=G,
NNLM=NNLM,
LODR_lm=LODR_lm,
context_graph=context_graph,
) )
for name, hyps in hyps_dict.items(): for name, hyps in hyps_dict.items():
@ -646,13 +777,32 @@ def decode_dataset(
return results return results
def save_results( def save_asr_output(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
"""
Save text produced by ASR.
"""
for key, results in results_dict.items():
recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recogs_filename, texts=results)
logging.info(f"The transcripts are stored in {recogs_filename}")
def save_wer_results(
params: AttributeDict, params: AttributeDict,
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
): ):
if params.decoding_method in ( if params.decoding_method in (
"attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring" "attention-decoder-rescoring-with-ngram",
"whole-lattice-rescoring",
): ):
# Set it to False since there are too many logs. # Set it to False since there are too many logs.
enable_log = False enable_log = False
@ -661,32 +811,30 @@ def save_results(
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f: with open(errs_filename, "w", encoding="utf8") as fd:
wer = write_error_stats(f, f"{test_set_name}-{key}", results) wer = write_error_stats(
fd, f"{test_set_name}_{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info(f"Wrote detailed error stats to {errs_filename}")
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name) wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
note = "\tbest for {}".format(test_set_name)
with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print(f"{key}\t{val}", file=fd)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers: for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note) s += f"{key}\t{val}{note}\n"
note = "" note = ""
logging.info(s) logging.info(s)
@ -695,6 +843,7 @@ def save_results(
def main(): def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) LibriSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir) args.lang_dir = Path(args.lang_dir)
@ -705,9 +854,15 @@ def main():
params.update(get_decoding_params()) params.update(get_decoding_params())
params.update(vars(args)) params.update(vars(args))
# enable AudioCache
set_caching_enabled(True) # lhotse
assert params.decoding_method in ( assert params.decoding_method in (
"ctc-greedy-search",
"ctc-decoding", "ctc-decoding",
"ctc-greedy-search",
"ctc-prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring",
"ctc-prefix-beam-search-shallow-fussion",
"1best", "1best",
"nbest", "nbest",
"nbest-rescoring", "nbest-rescoring",
@ -719,9 +874,9 @@ def main():
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0: if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}" params.suffix = f"iter-{params.iter}_avg-{params.avg}"
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
if params.causal: if params.causal:
assert ( assert (
@ -730,11 +885,21 @@ def main():
assert ( assert (
"," not in params.left_context_frames "," not in params.left_context_frames
), "left_context_frames should be one value in decoding." ), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}" params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}" params.suffix += f"_left-context-{params.left_context_frames}"
if "prefix-beam-search" in params.decoding_method:
params.suffix += f"_beam-{params.beam}"
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
if params.nnlm_scale != 0:
params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
if params.lodr_lm_scale != 0:
params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
if params.context_score != 0:
params.suffix += f"_context_score-{params.context_score}"
if params.use_averaged_model: if params.use_averaged_model:
params.suffix += "-use-averaged-model" params.suffix += "_use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started") logging.info("Decoding started")
@ -742,6 +907,7 @@ def main():
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
params.device = device
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
logging.info(params) logging.info(params)
@ -757,14 +923,24 @@ def main():
params.sos_id = 1 params.sos_id = 1
if params.decoding_method in [ if params.decoding_method in [
"ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram" "ctc-decoding",
"ctc-greedy-search",
"ctc-prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring",
"ctc-prefix-beam-search-shallow-fussion",
"attention-decoder-rescoring-no-ngram",
]: ]:
HLG = None HLG = None
H = k2.ctc_topo( H = None
max_token=max_token_id, if params.decoding_method in [
modified=False, "ctc-decoding",
device=device, "attention-decoder-rescoring-no-ngram",
) ]:
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor() bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model")) bpe_model.load(str(params.lang_dir / "bpe.model"))
else: else:
@ -815,7 +991,8 @@ def main():
G = k2.Fsa.from_dict(d) G = k2.Fsa.from_dict(d)
if params.decoding_method in [ if params.decoding_method in [
"whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram" "whole-lattice-rescoring",
"attention-decoder-rescoring-with-ngram",
]: ]:
# Add epsilon self-loops to G as we will compose # Add epsilon self-loops to G as we will compose
# it with the whole lattice later # it with the whole lattice later
@ -829,6 +1006,51 @@ def main():
else: else:
G = None G = None
# only load the neural network LM if required
NNLM = None
if (
params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
and params.nnlm_scale != 0
):
NNLM = LmScorer(
lm_type=params.nnlm_type,
params=params,
device=device,
lm_scale=params.nnlm_scale,
)
NNLM.to(device)
NNLM.eval()
LODR_lm = None
if (
params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
and params.lodr_lm_scale != 0
):
assert os.path.exists(
params.lodr_ngram
), f"LODR ngram does not exists, given path : {params.lodr_ngram}"
logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
LODR_lm = NgramLm(
params.lodr_ngram,
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {LODR_lm.lm.num_states}")
context_graph = None
if (
params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
and params.context_score != 0
):
assert os.path.exists(
params.context_file
), f"context_file does not exists, given path : {params.context_file}"
contexts = []
for line in open(params.context_file).readlines():
contexts.append(bpe_model.encode(line.strip()))
context_graph = ContextGraph(params.context_score)
context_graph.build(contexts)
logging.info("About to create model") logging.info("About to create model")
model = get_model(params) model = get_model(params)
@ -938,14 +1160,24 @@ def main():
bpe_model=bpe_model, bpe_model=bpe_model,
word_table=lexicon.word_table, word_table=lexicon.word_table,
G=G, G=G,
NNLM=NNLM,
LODR_lm=LODR_lm,
context_graph=context_graph,
) )
save_results( save_asr_output(
params=params, params=params,
test_set_name=test_set, test_set_name=test_set,
results_dict=results_dict, results_dict=results_dict,
) )
if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!") logging.info("Done!")
@ -121,6 +121,7 @@ from beam_search import (
modified_beam_search_lm_shallow_fusion, modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR, modified_beam_search_LODR,
) )
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm from icefall import ContextGraph, LmScorer, NgramLm
@ -369,6 +370,14 @@ def get_parser():
modified_beam_search_LODR. modified_beam_search_LODR.
""", """,
) )
parser.add_argument(
"--skip-scoring",
type=str2bool,
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -590,21 +599,23 @@ def decode_one_batch(
) )
hyps.append(sp.decode(hyp).split()) hyps.append(sp.decode(hyp).split())
# prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
prefix = f"{params.decoding_method}"
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method: elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_" prefix += f"_beam-{params.beam}"
key += f"max_contexts_{params.max_contexts}_" prefix += f"_max-contexts-{params.max_contexts}"
key += f"max_states_{params.max_states}" prefix += f"_max-states-{params.max_states}"
if "nbest" in params.decoding_method: if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_" prefix += f"_num-paths-{params.num_paths}"
key += f"nbest_scale_{params.nbest_scale}" prefix += f"_nbest-scale-{params.nbest_scale}"
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}" prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}"
return {key: hyps} return {prefix: hyps}
elif "modified_beam_search" in params.decoding_method: elif "modified_beam_search" in params.decoding_method:
prefix = f"beam_size_{params.beam_size}" prefix += f"_beam-size-{params.beam_size}"
if params.decoding_method in ( if params.decoding_method in (
"modified_beam_search_lm_rescore", "modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR", "modified_beam_search_lm_rescore_LODR",
@ -617,10 +628,11 @@ def decode_one_batch(
return ans return ans
else: else:
if params.has_contexts: if params.has_contexts:
prefix += f"-context-score-{params.context_score}" prefix += f"_context-score-{params.context_score}"
return {prefix: hyps} return {prefix: hyps}
else: else:
return {f"beam_size_{params.beam_size}": hyps} prefix += f"_beam-size-{params.beam_size}"
return {prefix: hyps}
def decode_dataset( def decode_dataset(
@ -707,46 +719,58 @@ def decode_dataset(
return results return results
def save_results( def save_asr_output(
params: AttributeDict, params: AttributeDict,
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
): ):
"""
Save text produced by ASR.
"""
for key, results in results_dict.items():
recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recogs_filename, texts=results)
logging.info(f"The transcripts are stored in {recogs_filename}")
def save_wer_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]],
):
"""
Save WER and per-utterance word alignments.
"""
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = ( errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" with open(errs_filename, "w", encoding="utf8") as fd:
)
with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True fd, f"{test_set_name}-{key}", results, enable_log=True
) )
test_set_wers[key] = wer test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info(f"Wrote detailed error stats to {errs_filename}")
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name) wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
note = "\tbest for {}".format(test_set_name)
with open(wer_filename, "w", encoding="utf8") as fd:
print("settings\tWER", file=fd)
for key, val in test_set_wers:
print(f"{key}\t{val}", file=fd)
s = f"\nFor {test_set_name}, WER of different settings are:\n"
note = f"\tbest for {test_set_name}"
for key, val in test_set_wers: for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note) s += f"{key}\t{val}{note}\n"
note = "" note = ""
logging.info(s) logging.info(s)
@ -762,6 +786,9 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
# enable AudioCache
set_caching_enabled(True) # lhotse
assert params.decoding_method in ( assert params.decoding_method in (
"greedy_search", "greedy_search",
"beam_search", "beam_search",
@ -783,9 +810,9 @@ def main():
params.has_contexts = False params.has_contexts = False
if params.iter > 0: if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}" params.suffix = f"iter-{params.iter}_avg-{params.avg}"
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
if params.causal: if params.causal:
assert ( assert (
@ -794,20 +821,20 @@ def main():
assert ( assert (
"," not in params.left_context_frames "," not in params.left_context_frames
), "left_context_frames should be one value in decoding." ), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}" params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}" params.suffix += f"_left-context-{params.left_context_frames}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"_beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"_max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"_max-states-{params.max_states}"
if "nbest" in params.decoding_method: if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"_nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"_num-paths-{params.num_paths}"
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}"
if params.decoding_method in ( if params.decoding_method in (
"modified_beam_search", "modified_beam_search",
"modified_beam_search_LODR", "modified_beam_search_LODR",
@ -815,19 +842,19 @@ def main():
if params.has_contexts: if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}" params.suffix += f"-context-score-{params.context_score}"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"_context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_shallow_fusion: if params.use_shallow_fusion:
params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method: if "LODR" in params.decoding_method:
params.suffix += ( params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
) )
if params.use_averaged_model: if params.use_averaged_model:
params.suffix += "-use-averaged-model" params.suffix += "_use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started") logging.info("Decoding started")
@ -1038,12 +1065,19 @@ def main():
ngram_lm_scale=ngram_lm_scale, ngram_lm_scale=ngram_lm_scale,
) )
save_results( save_asr_output(
params=params, params=params,
test_set_name=test_set, test_set_name=test_set,
results_dict=results_dict, results_dict=results_dict,
) )
if not params.skip_scoring:
save_wer_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!") logging.info("Done!")
@@ -74,7 +74,6 @@ import onnx
import torch
import torch.nn as nn
from decoder import Decoder
- from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@@ -488,6 +487,7 @@ def export_encoder_model_onnx(
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
@@ -755,29 +755,31 @@ def main():
)
logging.info(f"Exported joiner to {joiner_filename}")
- if(params.fp16) :
+ if params.fp16:
+ from onnxconverter_common import float16
logging.info("Generate fp16 models")
encoder = onnx.load(encoder_filename)
encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
- onnx.save(encoder_fp16,encoder_filename_fp16)
+ onnx.save(encoder_fp16, encoder_filename_fp16)
decoder = onnx.load(decoder_filename)
decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
- onnx.save(decoder_fp16,decoder_filename_fp16)
+ onnx.save(decoder_fp16, decoder_filename_fp16)
joiner = onnx.load(joiner_filename)
joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
- onnx.save(joiner_fp16,joiner_filename_fp16)
+ onnx.save(joiner_fp16, joiner_filename_fp16)
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,