mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 18:54:18 +00:00
Merge branch 'k2-fsa:master' into master
This commit is contained in:
commit
a82e0019ef
2
.flake8
2
.flake8
@ -15,7 +15,7 @@ per-file-ignores =
|
|||||||
egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
|
egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
|
||||||
egs/librispeech/ASR/zipformer/*.py: E501, E203
|
egs/librispeech/ASR/zipformer/*.py: E501, E203
|
||||||
egs/librispeech/ASR/RESULTS.md: E999,
|
egs/librispeech/ASR/RESULTS.md: E999,
|
||||||
|
egs/ljspeech/TTS/vits/*.py: E501, E203
|
||||||
# invalid escape sequence (cause by tex formular), W605
|
# invalid escape sequence (cause by tex formular), W605
|
||||||
icefall/utils.py: E501, W605
|
icefall/utils.py: E501, W605
|
||||||
|
|
||||||
|
158
.github/scripts/multi-zh-hans.sh
vendored
Executable file
158
.github/scripts/multi-zh-hans.sh
vendored
Executable file
@ -0,0 +1,158 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
git config --global user.name "k2-fsa"
|
||||||
|
git config --global user.email "csukuangfj@gmail.com"
|
||||||
|
git config --global lfs.allowincompletepush true
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
log "pwd: $PWD"
|
||||||
|
|
||||||
|
cd egs/multi_zh-hans/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
cd exp/
|
||||||
|
git lfs pull --include pretrained.pt
|
||||||
|
rm -fv epoch-20.pt
|
||||||
|
rm -fv *.onnx
|
||||||
|
ln -s pretrained.pt epoch-20.pt
|
||||||
|
cd ../data/lang_bpe_2000
|
||||||
|
ls -lh
|
||||||
|
git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model
|
||||||
|
git lfs pull --include "*.model"
|
||||||
|
ls -lh
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "----------------------------------------"
|
||||||
|
log "Export streaming ONNX CTC models "
|
||||||
|
log "----------------------------------------"
|
||||||
|
./zipformer/export-onnx-streaming-ctc.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
--causal 1 \
|
||||||
|
--avg 1 \
|
||||||
|
--epoch 20 \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--chunk-size 16 \
|
||||||
|
--left-context-frames 128 \
|
||||||
|
--use-ctc 1
|
||||||
|
|
||||||
|
ls -lh $repo/exp/
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Test exported streaming ONNX CTC models (greedy search) "
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
test_wavs=(
|
||||||
|
DEV_T0000000000.wav
|
||||||
|
DEV_T0000000001.wav
|
||||||
|
DEV_T0000000002.wav
|
||||||
|
TEST_MEETING_T0000000113.wav
|
||||||
|
TEST_MEETING_T0000000219.wav
|
||||||
|
TEST_MEETING_T0000000351.wav
|
||||||
|
)
|
||||||
|
|
||||||
|
for w in ${test_wavs[@]}; do
|
||||||
|
./zipformer/onnx_pretrained-streaming-ctc.py \
|
||||||
|
--model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
$repo/test_wavs/$w
|
||||||
|
done
|
||||||
|
|
||||||
|
log "Upload onnx CTC models to huggingface"
|
||||||
|
url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $url
|
||||||
|
dst=$(basename $url)
|
||||||
|
cp -v $repo/exp/ctc*.onnx $dst
|
||||||
|
cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
|
||||||
|
cp -v $repo/data/lang_bpe_2000/bpe.model $dst
|
||||||
|
mkdir -p $dst/test_wavs
|
||||||
|
cp -v $repo/test_wavs/*.wav $dst/test_wavs
|
||||||
|
cd $dst
|
||||||
|
git lfs track "*.onnx" "bpe.model"
|
||||||
|
ls -lh
|
||||||
|
file bpe.model
|
||||||
|
git status
|
||||||
|
git add .
|
||||||
|
git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
|
||||||
|
|
||||||
|
log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
|
||||||
|
rm -rf .git
|
||||||
|
rm -fv .gitattributes
|
||||||
|
cd ..
|
||||||
|
tar cjfv $dst.tar.bz2 $dst
|
||||||
|
ls -lh *.tar.bz2
|
||||||
|
mv -v $dst.tar.bz2 ../../../
|
||||||
|
|
||||||
|
log "----------------------------------------"
|
||||||
|
log "Export streaming ONNX transducer models "
|
||||||
|
log "----------------------------------------"
|
||||||
|
|
||||||
|
./zipformer/export-onnx-streaming.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
--causal 1 \
|
||||||
|
--avg 1 \
|
||||||
|
--epoch 20 \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--chunk-size 16 \
|
||||||
|
--left-context-frames 128 \
|
||||||
|
--use-ctc 0
|
||||||
|
|
||||||
|
ls -lh $repo/exp
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Test exported streaming ONNX transducer models (Python code)"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
log "test fp32"
|
||||||
|
./zipformer/onnx_pretrained-streaming.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
$repo/test_wavs/DEV_T0000000000.wav
|
||||||
|
|
||||||
|
log "test int8"
|
||||||
|
./zipformer/onnx_pretrained-streaming.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
$repo/test_wavs/DEV_T0000000000.wav
|
||||||
|
|
||||||
|
log "Upload onnx transducer models to huggingface"
|
||||||
|
|
||||||
|
url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $url
|
||||||
|
dst=$(basename $url)
|
||||||
|
cp -v $repo/exp/encoder*.onnx $dst
|
||||||
|
cp -v $repo/exp/decoder*.onnx $dst
|
||||||
|
cp -v $repo/exp/joiner*.onnx $dst
|
||||||
|
cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
|
||||||
|
cp -v $repo/data/lang_bpe_2000/bpe.model $dst
|
||||||
|
mkdir -p $dst/test_wavs
|
||||||
|
cp -v $repo/test_wavs/*.wav $dst/test_wavs
|
||||||
|
cd $dst
|
||||||
|
git lfs track "*.onnx" bpe.model
|
||||||
|
git add .
|
||||||
|
git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
|
||||||
|
|
||||||
|
log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
|
||||||
|
rm -rf .git
|
||||||
|
rm -fv .gitattributes
|
||||||
|
cd ..
|
||||||
|
tar cjfv $dst.tar.bz2 $dst
|
||||||
|
ls -lh *.tar.bz2
|
||||||
|
mv -v $dst.tar.bz2 ../../../
|
@ -18,8 +18,8 @@ log "Downloading pre-commputed fbank from $fbank_url"
|
|||||||
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
|
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
|
||||||
ln -s $PWD/aishell-test-dev-manifests/data .
|
ln -s $PWD/aishell-test-dev-manifests/data .
|
||||||
|
|
||||||
log "Downloading pre-trained model from $repo_url"
|
|
||||||
repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
|
repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
git clone $repo_url
|
git clone $repo_url
|
||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
103
.github/scripts/run-aishell-zipformer-2023-10-24.sh
vendored
Executable file
103
.github/scripts/run-aishell-zipformer-2023-10-24.sh
vendored
Executable file
@ -0,0 +1,103 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
cd egs/aishell/ASR
|
||||||
|
|
||||||
|
git lfs install
|
||||||
|
|
||||||
|
fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests
|
||||||
|
log "Downloading pre-commputed fbank from $fbank_url"
|
||||||
|
|
||||||
|
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
|
||||||
|
ln -s $PWD/aishell-test-dev-manifests/data .
|
||||||
|
|
||||||
|
log "======================="
|
||||||
|
log "CI testing large model"
|
||||||
|
repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
for method in modified_beam_search greedy_search fast_beam_search; do
|
||||||
|
log "$method"
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--method $method \
|
||||||
|
--context-size 1 \
|
||||||
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
|
--tokens $repo/data/lang_char/tokens.txt \
|
||||||
|
--num-encoder-layers 2,2,4,5,4,2 \
|
||||||
|
--feedforward-dim 512,768,1536,2048,1536,768 \
|
||||||
|
--encoder-dim 192,256,512,768,512,256 \
|
||||||
|
--encoder-unmasked-dim 192,192,256,320,256,192 \
|
||||||
|
$repo/test_wavs/BAC009S0764W0121.wav \
|
||||||
|
$repo/test_wavs/BAC009S0764W0122.wav \
|
||||||
|
$repo/test_wavs/BAC009S0764W0123.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
log "======================="
|
||||||
|
log "CI testing medium model"
|
||||||
|
repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
|
||||||
|
for method in modified_beam_search greedy_search fast_beam_search; do
|
||||||
|
log "$method"
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--method $method \
|
||||||
|
--context-size 1 \
|
||||||
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
|
--tokens $repo/data/lang_char/tokens.txt \
|
||||||
|
$repo/test_wavs/BAC009S0764W0121.wav \
|
||||||
|
$repo/test_wavs/BAC009S0764W0122.wav \
|
||||||
|
$repo/test_wavs/BAC009S0764W0123.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
|
||||||
|
log "======================="
|
||||||
|
log "CI testing small model"
|
||||||
|
repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
|
||||||
|
for method in modified_beam_search greedy_search fast_beam_search; do
|
||||||
|
log "$method"
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--method $method \
|
||||||
|
--context-size 1 \
|
||||||
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
|
--tokens $repo/data/lang_char/tokens.txt \
|
||||||
|
--num-encoder-layers 2,2,2,2,2,2 \
|
||||||
|
--feedforward-dim 512,768,768,768,768,768 \
|
||||||
|
--encoder-dim 192,256,256,256,256,256 \
|
||||||
|
--encoder-unmasked-dim 192,192,192,192,192,192 \
|
||||||
|
$repo/test_wavs/BAC009S0764W0121.wav \
|
||||||
|
$repo/test_wavs/BAC009S0764W0122.wav \
|
||||||
|
$repo/test_wavs/BAC009S0764W0123.wav
|
||||||
|
done
|
||||||
|
|
158
.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh
vendored
Executable file
158
.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh
vendored
Executable file
@ -0,0 +1,158 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
cd egs/gigaspeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/yfyeung/icefall-asr-gigaspeech-zipformer-2023-10-17
|
||||||
|
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
pushd $repo/exp
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "data/lang_bpe_500/tokens.txt"
|
||||||
|
git lfs pull --include "exp/jit_script.pt"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
rm epoch-30.pt
|
||||||
|
ln -s pretrained.pt epoch-30.pt
|
||||||
|
rm *.onnx
|
||||||
|
ls -lh
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "----------------------------------------"
|
||||||
|
log "Export ONNX transducer models "
|
||||||
|
log "----------------------------------------"
|
||||||
|
|
||||||
|
./zipformer/export-onnx.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp
|
||||||
|
|
||||||
|
ls -lh $repo/exp
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Test exported ONNX transducer models (Python code) "
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
log "test fp32"
|
||||||
|
./zipformer/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-30-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-30-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-30-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
log "test int8"
|
||||||
|
./zipformer/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-30-avg-1.int8.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-30-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-30-avg-1.int8.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
log "Upload models to huggingface"
|
||||||
|
git config --global user.name "k2-fsa"
|
||||||
|
git config --global user.email "xxx@gmail.com"
|
||||||
|
|
||||||
|
url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-gigaspeech-2023-12-12
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $url
|
||||||
|
dst=$(basename $url)
|
||||||
|
cp -v $repo/exp/*.onnx $dst
|
||||||
|
cp -v $repo/data/lang_bpe_500/tokens.txt $dst
|
||||||
|
cp -v $repo/data/lang_bpe_500/bpe.model $dst
|
||||||
|
mkdir -p $dst/test_wavs
|
||||||
|
cp -v $repo/test_wavs/*.wav $dst/test_wavs
|
||||||
|
cd $dst
|
||||||
|
git lfs track "*.onnx"
|
||||||
|
git add .
|
||||||
|
git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
|
||||||
|
|
||||||
|
log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
|
||||||
|
rm -rf .git
|
||||||
|
rm -fv .gitattributes
|
||||||
|
cd ..
|
||||||
|
tar cjfv $dst.tar.bz2 $dst
|
||||||
|
ls -lh
|
||||||
|
mv -v $dst.tar.bz2 ../../../
|
||||||
|
|
||||||
|
log "Export to torchscript model"
|
||||||
|
./zipformer/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--use-averaged-model false \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 1 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
|
log "Decode with models exported by torch.jit.script()"
|
||||||
|
|
||||||
|
./zipformer/jit_pretrained.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--nn-model-filename $repo/exp/jit_script.pt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
for method in greedy_search modified_beam_search fast_beam_search; do
|
||||||
|
log "$method"
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--method $method \
|
||||||
|
--beam-size 4 \
|
||||||
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
|
mkdir -p zipformer/exp
|
||||||
|
ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-30.pt
|
||||||
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
|
|
||||||
|
ls -lh data
|
||||||
|
ls -lh zipformer/exp
|
||||||
|
|
||||||
|
log "Decoding test-clean and test-other"
|
||||||
|
|
||||||
|
# use a small value for decoding with CPU
|
||||||
|
max_duration=100
|
||||||
|
|
||||||
|
for method in greedy_search fast_beam_search modified_beam_search; do
|
||||||
|
log "Decoding with $method"
|
||||||
|
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--decoding-method $method \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 1 \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--max-duration $max_duration \
|
||||||
|
--exp-dir zipformer/exp
|
||||||
|
done
|
||||||
|
|
||||||
|
rm zipformer/exp/*.pt
|
||||||
|
fi
|
135
.github/scripts/run-multi-corpora-zipformer.sh
vendored
Executable file
135
.github/scripts/run-multi-corpora-zipformer.sh
vendored
Executable file
@ -0,0 +1,135 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
cd egs/multi_zh-hans/ASR
|
||||||
|
|
||||||
|
log "==== Test icefall-asr-multi-zh-hans-zipformer-2023-9-2 ===="
|
||||||
|
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
|
||||||
|
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
pushd $repo/exp
|
||||||
|
ln -s epoch-20.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint $repo/exp/epoch-99.pt \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
--method greedy_search \
|
||||||
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
|
|
||||||
|
for method in modified_beam_search fast_beam_search; do
|
||||||
|
log "$method"
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--method $method \
|
||||||
|
--beam-size 4 \
|
||||||
|
--checkpoint $repo/exp/epoch-99.pt \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 ===="
|
||||||
|
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/
|
||||||
|
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
pushd $repo/exp
|
||||||
|
ln -s epoch-20.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint $repo/exp/epoch-99.pt \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
--use-ctc 1 \
|
||||||
|
--method greedy_search \
|
||||||
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
|
|
||||||
|
for method in modified_beam_search fast_beam_search; do
|
||||||
|
log "$method"
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--method $method \
|
||||||
|
--beam-size 4 \
|
||||||
|
--use-ctc 1 \
|
||||||
|
--checkpoint $repo/exp/epoch-99.pt \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
cd ../../../egs/multi_zh_en/ASR
|
||||||
|
log "==== Test icefall-asr-zipformer-multi-zh-en-2023-11-22 ===="
|
||||||
|
repo_url=https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22/
|
||||||
|
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
|
--bpe-model $repo/data/lang_bbpe_2000/bbpe.model \
|
||||||
|
--method greedy_search \
|
||||||
|
$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \
|
||||||
|
$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \
|
||||||
|
$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav
|
||||||
|
|
||||||
|
for method in modified_beam_search fast_beam_search; do
|
||||||
|
log "$method"
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--method $method \
|
||||||
|
--beam-size 4 \
|
||||||
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
|
--bpe-model $repo/data/lang_bbpe_2000/bbpe.model \
|
||||||
|
$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \
|
||||||
|
$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \
|
||||||
|
$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -rf $repo
|
51
.github/scripts/run-multi-zh_hans-zipformer.sh
vendored
51
.github/scripts/run-multi-zh_hans-zipformer.sh
vendored
@ -1,51 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
log() {
|
|
||||||
# This function is from espnet
|
|
||||||
local fname=${BASH_SOURCE[1]##*/}
|
|
||||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
|
||||||
}
|
|
||||||
|
|
||||||
cd egs/multi_zh-hans/ASR
|
|
||||||
|
|
||||||
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
|
|
||||||
|
|
||||||
log "Downloading pre-trained model from $repo_url"
|
|
||||||
git lfs install
|
|
||||||
git clone $repo_url
|
|
||||||
repo=$(basename $repo_url)
|
|
||||||
|
|
||||||
|
|
||||||
log "Display test files"
|
|
||||||
tree $repo/
|
|
||||||
ls -lh $repo/test_wavs/*.wav
|
|
||||||
|
|
||||||
pushd $repo/exp
|
|
||||||
ln -s epoch-20.pt epoch-99.pt
|
|
||||||
popd
|
|
||||||
|
|
||||||
ls -lh $repo/exp/*.pt
|
|
||||||
|
|
||||||
|
|
||||||
./zipformer/pretrained.py \
|
|
||||||
--checkpoint $repo/exp/epoch-99.pt \
|
|
||||||
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
|
||||||
--method greedy_search \
|
|
||||||
$repo/test_wavs/DEV_T0000000000.wav \
|
|
||||||
$repo/test_wavs/DEV_T0000000001.wav \
|
|
||||||
$repo/test_wavs/DEV_T0000000002.wav
|
|
||||||
|
|
||||||
for method in modified_beam_search fast_beam_search; do
|
|
||||||
log "$method"
|
|
||||||
|
|
||||||
./zipformer/pretrained.py \
|
|
||||||
--method $method \
|
|
||||||
--beam-size 4 \
|
|
||||||
--checkpoint $repo/exp/epoch-99.pt \
|
|
||||||
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
|
||||||
$repo/test_wavs/DEV_T0000000000.wav \
|
|
||||||
$repo/test_wavs/DEV_T0000000001.wav \
|
|
||||||
$repo/test_wavs/DEV_T0000000002.wav
|
|
||||||
done
|
|
44
.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh
vendored
Executable file
44
.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh
vendored
Executable file
@ -0,0 +1,44 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
cd egs/swbd/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26
|
||||||
|
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
pushd $repo/exp
|
||||||
|
ln -s epoch-98.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
|
for method in ctc-decoding 1best; do
|
||||||
|
log "$method"
|
||||||
|
|
||||||
|
./conformer_ctc/pretrained.py \
|
||||||
|
--method $method \
|
||||||
|
--checkpoint $repo/exp/epoch-99.pt \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
done
|
9
.github/workflows/build-docker-image.yml
vendored
9
.github/workflows/build-docker-image.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
# refer to https://github.com/actions/checkout
|
# refer to https://github.com/actions/checkout
|
||||||
@ -30,6 +30,13 @@ jobs:
|
|||||||
image=${{ matrix.image }}
|
image=${{ matrix.image }}
|
||||||
mv -v ./docker/$image.dockerfile ./Dockerfile
|
mv -v ./docker/$image.dockerfile ./Dockerfile
|
||||||
|
|
||||||
|
- name: Free space
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
df -h
|
||||||
|
rm -rf /opt/hostedtoolcache
|
||||||
|
df -h
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
- name: Log in to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
|
79
.github/workflows/multi-zh-hans.yml
vendored
Normal file
79
.github/workflows/multi-zh-hans.yml
vendored
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
name: run-multi-zh-hans
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: run-multi-zh-hans-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
multi-zh-hans:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: [3.8]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: '**/requirements-ci.txt'
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
run: |
|
||||||
|
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||||
|
pip uninstall -y protobuf
|
||||||
|
pip install --no-binary protobuf protobuf==3.20.*
|
||||||
|
|
||||||
|
- name: Cache kaldifeat
|
||||||
|
id: my-cache
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/kaldifeat
|
||||||
|
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||||
|
|
||||||
|
- name: Install kaldifeat
|
||||||
|
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
|
- name: export-model
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
run: |
|
||||||
|
sudo apt-get -qq install git-lfs tree
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
|
.github/scripts/multi-zh-hans.sh
|
||||||
|
ls -lh
|
||||||
|
|
||||||
|
- name: upload model to https://github.com/k2-fsa/sherpa-onnx
|
||||||
|
uses: svenstaro/upload-release-action@v2
|
||||||
|
with:
|
||||||
|
file_glob: true
|
||||||
|
file: ./*.tar.bz2
|
||||||
|
overwrite: true
|
||||||
|
repo_name: k2-fsa/sherpa-onnx
|
||||||
|
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||||
|
tag: asr-models
|
95
.github/workflows/run-aishell-zipformer-2023-10-24.yml
vendored
Normal file
95
.github/workflows/run-aishell-zipformer-2023-10-24.yml
vendored
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
# Copyright 2023 Zengrui Jin (Xiaomi Corp.)
|
||||||
|
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
name: run-aishell-zipformer-2023-10-24
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
types: [labeled]
|
||||||
|
|
||||||
|
schedule:
|
||||||
|
# minute (0-59)
|
||||||
|
# hour (0-23)
|
||||||
|
# day of the month (1-31)
|
||||||
|
# month (1-12)
|
||||||
|
# day of the week (0-6)
|
||||||
|
# nightly build at 15:50 UTC time every day
|
||||||
|
- cron: "50 15 * * *"
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: run_aishell_zipformer_2023_10_24-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
run_aishell_zipformer_2023_10_24:
|
||||||
|
if: github.event.label.name == 'ready' || github.event.label.name == 'zipformer' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: [3.8]
|
||||||
|
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: '**/requirements-ci.txt'
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
run: |
|
||||||
|
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||||
|
pip uninstall -y protobuf
|
||||||
|
pip install --no-binary protobuf protobuf==3.20.*
|
||||||
|
|
||||||
|
- name: Cache kaldifeat
|
||||||
|
id: my-cache
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/kaldifeat
|
||||||
|
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||||
|
|
||||||
|
- name: Install kaldifeat
|
||||||
|
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
|
- name: Inference with pre-trained model
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
|
run: |
|
||||||
|
sudo apt-get -qq install git-lfs tree
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
|
.github/scripts/run-aishell-zipformer-2023-10-24.sh
|
||||||
|
|
||||||
|
|
15
.github/workflows/run-docker-image.yml
vendored
15
.github/workflows/run-docker-image.yml
vendored
@ -14,7 +14,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
image: ["torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
|
||||||
steps:
|
steps:
|
||||||
# refer to https://github.com/actions/checkout
|
# refer to https://github.com/actions/checkout
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
@ -30,8 +30,15 @@ jobs:
|
|||||||
uname -a
|
uname -a
|
||||||
cat /etc/*release
|
cat /etc/*release
|
||||||
|
|
||||||
|
find / -name libcuda* 2>/dev/null
|
||||||
|
|
||||||
|
ls -lh /usr/local/
|
||||||
|
ls -lh /usr/local/cuda*
|
||||||
|
|
||||||
nvcc --version
|
nvcc --version
|
||||||
|
|
||||||
|
ls -lh /usr/local/cuda-*/compat/*
|
||||||
|
|
||||||
# For torch1.9.0-cuda10.2
|
# For torch1.9.0-cuda10.2
|
||||||
export LD_LIBRARY_PATH=/usr/local/cuda-10.2/compat:$LD_LIBRARY_PATH
|
export LD_LIBRARY_PATH=/usr/local/cuda-10.2/compat:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
@ -41,6 +48,12 @@ jobs:
|
|||||||
# For torch2.0.0-cuda11.7
|
# For torch2.0.0-cuda11.7
|
||||||
export LD_LIBRARY_PATH=/usr/local/cuda-11.7/compat:$LD_LIBRARY_PATH
|
export LD_LIBRARY_PATH=/usr/local/cuda-11.7/compat:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
# For torch2.1.0-cuda11.8
|
||||||
|
export LD_LIBRARY_PATH=/usr/local/cuda-11.8/compat:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
# For torch2.1.0-cuda12.1
|
||||||
|
export LD_LIBRARY_PATH=/usr/local/cuda-12.1/compat:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
|
||||||
which nvcc
|
which nvcc
|
||||||
cuda_dir=$(dirname $(which nvcc))
|
cuda_dir=$(dirname $(which nvcc))
|
||||||
|
140
.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml
vendored
Normal file
140
.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml
vendored
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
|
||||||
|
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
name: run-gigaspeech-zipformer-2023-10-17
|
||||||
|
# zipformer
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
types: [labeled]
|
||||||
|
|
||||||
|
schedule:
|
||||||
|
# minute (0-59)
|
||||||
|
# hour (0-23)
|
||||||
|
# day of the month (1-31)
|
||||||
|
# month (1-12)
|
||||||
|
# day of the week (0-6)
|
||||||
|
# nightly build at 15:50 UTC time every day
|
||||||
|
- cron: "50 15 * * *"
|
||||||
|
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: run_gigaspeech_2023_10_17_zipformer-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
run_gigaspeech_2023_10_17_zipformer:
|
||||||
|
if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: [3.8]
|
||||||
|
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: '**/requirements-ci.txt'
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
run: |
|
||||||
|
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||||
|
pip uninstall -y protobuf
|
||||||
|
pip install --no-binary protobuf protobuf==3.20.*
|
||||||
|
|
||||||
|
- name: Cache kaldifeat
|
||||||
|
id: my-cache
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/kaldifeat
|
||||||
|
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||||
|
|
||||||
|
- name: Install kaldifeat
|
||||||
|
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
|
- name: Inference with pre-trained model
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
run: |
|
||||||
|
mkdir -p egs/gigaspeech/ASR/data
|
||||||
|
ln -sfv ~/tmp/fbank-libri egs/gigaspeech/ASR/data/fbank
|
||||||
|
ls -lh egs/gigaspeech/ASR/data/*
|
||||||
|
|
||||||
|
sudo apt-get -qq install git-lfs tree
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
|
.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh
|
||||||
|
|
||||||
|
- name: upload model to https://github.com/k2-fsa/sherpa-onnx
|
||||||
|
uses: svenstaro/upload-release-action@v2
|
||||||
|
with:
|
||||||
|
file_glob: true
|
||||||
|
file: ./*.tar.bz2
|
||||||
|
overwrite: true
|
||||||
|
repo_name: k2-fsa/sherpa-onnx
|
||||||
|
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||||
|
tag: asr-models
|
||||||
|
|
||||||
|
- name: Display decoding results for gigaspeech zipformer
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd egs/gigaspeech/ASR/
|
||||||
|
tree ./zipformer/exp
|
||||||
|
|
||||||
|
cd zipformer
|
||||||
|
echo "results for zipformer"
|
||||||
|
echo "===greedy search==="
|
||||||
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===fast_beam_search==="
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
echo "===modified beam search==="
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
|
||||||
|
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
|
||||||
|
|
||||||
|
- name: Upload decoding results for gigaspeech zipformer
|
||||||
|
uses: actions/upload-artifact@v2
|
||||||
|
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
|
||||||
|
with:
|
||||||
|
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
|
||||||
|
path: egs/gigaspeech/ASR/zipformer/exp/
|
84
.github/workflows/run-multi-corpora-zipformer.yml
vendored
Normal file
84
.github/workflows/run-multi-corpora-zipformer.yml
vendored
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin)
|
||||||
|
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
name: run-multi-corpora-zipformer
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
types: [labeled]
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: run_multi-corpora_zipformer-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
run_multi-corpora_zipformer:
|
||||||
|
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer' || github.event.label.name == 'multi-corpora'
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: [3.8]
|
||||||
|
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: '**/requirements-ci.txt'
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
run: |
|
||||||
|
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||||
|
pip uninstall -y protobuf
|
||||||
|
pip install --no-binary protobuf protobuf==3.20.*
|
||||||
|
|
||||||
|
- name: Cache kaldifeat
|
||||||
|
id: my-cache
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/kaldifeat
|
||||||
|
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||||
|
|
||||||
|
- name: Install kaldifeat
|
||||||
|
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
|
- name: Inference with pre-trained model
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
|
run: |
|
||||||
|
sudo apt-get -qq install git-lfs tree
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
|
.github/scripts/run-multi-corpora-zipformer.sh
|
@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
name: run-multi-zh_hans-zipformer
|
name: run-swbd-conformer_ctc
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -24,12 +24,12 @@ on:
|
|||||||
types: [labeled]
|
types: [labeled]
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: run_multi-zh_hans_zipformer-${{ github.ref }}
|
group: run-swbd-conformer_ctc-${{ github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_multi-zh_hans_zipformer:
|
run-swbd-conformer_ctc:
|
||||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans'
|
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'swbd'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@ -81,4 +81,4 @@ jobs:
|
|||||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
.github/scripts/run-multi-zh_hans-zipformer.sh
|
.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh
|
4
.github/workflows/run-yesno-recipe.yml
vendored
4
.github/workflows/run-yesno-recipe.yml
vendored
@ -64,8 +64,8 @@ jobs:
|
|||||||
pip uninstall -y protobuf
|
pip uninstall -y protobuf
|
||||||
pip install --no-binary protobuf protobuf==3.20.*
|
pip install --no-binary protobuf protobuf==3.20.*
|
||||||
|
|
||||||
pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl
|
pip install --no-deps --force-reinstall k2==1.24.4.dev20231021+cpu.torch1.13.1 -f https://k2-fsa.github.io/k2/cpu.html
|
||||||
pip install kaldifeat==1.25.0.dev20230726+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
pip install kaldifeat==1.25.1.dev20231022+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
||||||
|
|
||||||
- name: Run yesno recipe
|
- name: Run yesno recipe
|
||||||
shell: bash
|
shell: bash
|
||||||
|
33
README.md
33
README.md
@ -29,6 +29,7 @@ We provide the following recipes:
|
|||||||
- [yesno][yesno]
|
- [yesno][yesno]
|
||||||
- [LibriSpeech][librispeech]
|
- [LibriSpeech][librispeech]
|
||||||
- [GigaSpeech][gigaspeech]
|
- [GigaSpeech][gigaspeech]
|
||||||
|
- [AMI][ami]
|
||||||
- [Aishell][aishell]
|
- [Aishell][aishell]
|
||||||
- [Aishell2][aishell2]
|
- [Aishell2][aishell2]
|
||||||
- [Aishell4][aishell4]
|
- [Aishell4][aishell4]
|
||||||
@ -37,6 +38,7 @@ We provide the following recipes:
|
|||||||
- [Aidatatang_200zh][aidatatang_200zh]
|
- [Aidatatang_200zh][aidatatang_200zh]
|
||||||
- [WenetSpeech][wenetspeech]
|
- [WenetSpeech][wenetspeech]
|
||||||
- [Alimeeting][alimeeting]
|
- [Alimeeting][alimeeting]
|
||||||
|
- [Switchboard][swbd]
|
||||||
- [TAL_CSASR][tal_csasr]
|
- [TAL_CSASR][tal_csasr]
|
||||||
|
|
||||||
### yesno
|
### yesno
|
||||||
@ -116,11 +118,12 @@ We provide a Colab notebook to run a pre-trained transducer conformer + stateles
|
|||||||
|
|
||||||
#### k2 pruned RNN-T
|
#### k2 pruned RNN-T
|
||||||
|
|
||||||
| Encoder | Params | test-clean | test-other |
|
| Encoder | Params | test-clean | test-other | epochs | devices |
|
||||||
|-----------------|--------|------------|------------|
|
|-----------------|--------|------------|------------|---------|------------|
|
||||||
| zipformer | 65.5M | 2.21 | 4.91 |
|
| zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 |
|
||||||
| zipformer-small | 23.2M | 2.46 | 5.83 |
|
| zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 |
|
||||||
| zipformer-large | 148.4M | 2.11 | 4.77 |
|
| zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 |
|
||||||
|
| zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 |
|
||||||
|
|
||||||
Note: No auxiliary losses are used in the training and no LMs are used
|
Note: No auxiliary losses are used in the training and no LMs are used
|
||||||
in the decoding.
|
in the decoding.
|
||||||
@ -146,8 +149,11 @@ in the decoding.
|
|||||||
|
|
||||||
### GigaSpeech
|
### GigaSpeech
|
||||||
|
|
||||||
We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
|
We provide three models for this recipe:
|
||||||
and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
|
|
||||||
|
- [Conformer CTC model][GigaSpeech_conformer_ctc]
|
||||||
|
- [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
|
||||||
|
- [Transducer: Zipformer encoder + Embedding decoder][GigaSpeech_zipformer]
|
||||||
|
|
||||||
#### Conformer CTC
|
#### Conformer CTC
|
||||||
|
|
||||||
@ -163,6 +169,14 @@ and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned R
|
|||||||
| fast beam search | 10.50 | 10.69 |
|
| fast beam search | 10.50 | 10.69 |
|
||||||
| modified beam search | 10.40 | 10.51 |
|
| modified beam search | 10.40 | 10.51 |
|
||||||
|
|
||||||
|
#### Transducer: Zipformer encoder + Embedding decoder
|
||||||
|
|
||||||
|
| | Dev | Test |
|
||||||
|
|----------------------|-------|-------|
|
||||||
|
| greedy search | 10.31 | 10.50 |
|
||||||
|
| fast beam search | 10.26 | 10.48 |
|
||||||
|
| modified beam search | 10.25 | 10.38 |
|
||||||
|
|
||||||
|
|
||||||
### Aishell
|
### Aishell
|
||||||
|
|
||||||
@ -353,7 +367,7 @@ Once you have trained a model in icefall, you may want to deploy it with C++,
|
|||||||
without Python dependencies.
|
without Python dependencies.
|
||||||
|
|
||||||
Please refer to the documentation
|
Please refer to the documentation
|
||||||
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html#deployment-with-c>
|
<https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/librispeech/conformer_ctc.html#deployment-with-c>
|
||||||
for how to do this.
|
for how to do this.
|
||||||
|
|
||||||
We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++.
|
We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++.
|
||||||
@ -376,6 +390,7 @@ Please see: [
|
|
||||||
- end in a newline and only a newline
|
|
||||||
- contain sorted `imports` (checked by [isort][isort])
|
|
||||||
|
|
||||||
These hooks are disabled by default. Please use the following commands to enable them:
|
- Ensuring there are no trailing spaces.
|
||||||
|
- Formatting code with [black](https://github.com/psf/black).
|
||||||
|
- Checking compliance with PEP8 using [flake8](https://flake8.pycqa.org/).
|
||||||
|
- Verifying that files end with a newline character (and only a newline).
|
||||||
|
- Sorting imports using [isort](https://pycqa.github.io/isort/).
|
||||||
|
|
||||||
```bash
|
Please note that these hooks are disabled by default. To enable them, follow these steps:
|
||||||
pip install pre-commit # run it only once
|
|
||||||
pre-commit install # run it only once, it will install all hooks
|
|
||||||
|
|
||||||
# modify some files
|
### Installation (Run only once)
|
||||||
git add <some files>
|
|
||||||
git commit # It runs all hooks automatically.
|
|
||||||
|
|
||||||
# If all hooks run successfully, you can write the commit message now. Done!
|
1. Install the `pre-commit` package using pip:
|
||||||
#
|
```bash
|
||||||
# If any hook failed, your commit was not successful.
|
pip install pre-commit
|
||||||
# Please read the error messages and make changes accordingly.
|
```
|
||||||
# And rerun
|
1. Install the Git hooks using:
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
### Making a Commit
|
||||||
|
Once you have enabled the pre-commit hooks, follow these steps when making a commit:
|
||||||
|
1. Make your changes to the codebase.
|
||||||
|
2. Stage your changes by using git add for the files you modified.
|
||||||
|
3. Commit your changes using git commit. The pre-commit hooks will run automatically at this point.
|
||||||
|
4. If all hooks run successfully, you can write your commit message, and your changes will be successfully committed.
|
||||||
|
5. If any hook fails, your commit will not be successful. Please read and follow the error messages provided, make the necessary changes, and then re-run git add and git commit.
|
||||||
|
|
||||||
git add <some files>
|
### Your Contribution
|
||||||
git commit
|
Your contributions are valuable to us, and by following these guidelines, you help maintain code consistency and quality in our project. We appreciate your dedication to ensuring high-quality code. If you have questions or need assistance, feel free to reach out to us. Thank you for being part of our open-source community!
|
||||||
```
|
|
||||||
|
|
||||||
[git]: https://git-scm.com/book/en/v2/Customizing-Git-Git-Hooks
|
|
||||||
[flake8]: https://github.com/PyCQA/flake8
|
|
||||||
[PEP8]: https://www.python.org/dev/peps/pep-0008/
|
|
||||||
[black]: https://github.com/psf/black
|
|
||||||
[hooks]: https://github.com/pre-commit/pre-commit-hooks
|
|
||||||
[pre-commit]: https://github.com/pre-commit/pre-commit
|
|
||||||
[isort]: https://github.com/PyCQA/isort
|
|
||||||
|
@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8
|
|||||||
|
|
||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
ARG K2_VERSION="1.24.3.dev20230725+cuda11.3.torch1.12.1"
|
# python 3.7
|
||||||
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.3.torch1.12.1"
|
ARG K2_VERSION="1.24.4.dev20230725+cuda11.3.torch1.12.1"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.3.torch1.12.1"
|
||||||
ARG TORCHAUDIO_VERSION="0.12.1+cu113"
|
ARG TORCHAUDIO_VERSION="0.12.1+cu113"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
@ -43,7 +44,6 @@ RUN pip install --no-cache-dir \
|
|||||||
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
git+https://github.com/lhotse-speech/lhotse \
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
\
|
|
||||||
kaldi_native_io \
|
kaldi_native_io \
|
||||||
kaldialign \
|
kaldialign \
|
||||||
kaldifst \
|
kaldifst \
|
||||||
|
@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8
|
|||||||
|
|
||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
ARG K2_VERSION="1.24.3.dev20230725+cuda11.6.torch1.13.0"
|
# python 3.9
|
||||||
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.6.torch1.13.0"
|
ARG K2_VERSION="1.24.4.dev20231021+cuda11.6.torch1.13.0"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.6.torch1.13.0"
|
||||||
ARG TORCHAUDIO_VERSION="0.13.0+cu116"
|
ARG TORCHAUDIO_VERSION="0.13.0+cu116"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
@ -43,7 +44,6 @@ RUN pip install --no-cache-dir \
|
|||||||
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
git+https://github.com/lhotse-speech/lhotse \
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
\
|
|
||||||
kaldi_native_io \
|
kaldi_native_io \
|
||||||
kaldialign \
|
kaldialign \
|
||||||
kaldifst \
|
kaldifst \
|
||||||
|
@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8
|
|||||||
|
|
||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.7
|
||||||
ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0"
|
ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0"
|
||||||
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda10.2.torch1.9.0"
|
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda10.2.torch1.9.0"
|
||||||
ARG TORCHAUDIO_VERSION="0.9.0"
|
ARG TORCHAUDIO_VERSION="0.9.0"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
@ -57,7 +58,6 @@ RUN pip uninstall -y tqdm && \
|
|||||||
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
git+https://github.com/lhotse-speech/lhotse \
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
\
|
|
||||||
kaldi_native_io \
|
kaldi_native_io \
|
||||||
kaldialign \
|
kaldialign \
|
||||||
kaldifst \
|
kaldifst \
|
||||||
|
@ -4,8 +4,9 @@ ENV LC_ALL C.UTF-8
|
|||||||
|
|
||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
ARG K2_VERSION="1.24.3.dev20230718+cuda11.7.torch2.0.0"
|
# python 3.10
|
||||||
ARG KALDIFEAT_VERSION="1.25.0.dev20230726+cuda11.7.torch2.0.0"
|
ARG K2_VERSION="1.24.4.dev20231021+cuda11.7.torch2.0.0"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.7.torch2.0.0"
|
||||||
ARG TORCHAUDIO_VERSION="2.0.0+cu117"
|
ARG TORCHAUDIO_VERSION="2.0.0+cu117"
|
||||||
|
|
||||||
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
@ -43,7 +44,6 @@ RUN pip install --no-cache-dir \
|
|||||||
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
git+https://github.com/lhotse-speech/lhotse \
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
\
|
|
||||||
kaldi_native_io \
|
kaldi_native_io \
|
||||||
kaldialign \
|
kaldialign \
|
||||||
kaldifst \
|
kaldifst \
|
||||||
|
70
docker/torch2.1.0-cuda11.8.dockerfile
Normal file
70
docker/torch2.1.0-cuda11.8.dockerfile
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20231021+cuda11.8.torch2.1.0"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.8.torch2.1.0"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.1.0+cu118"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
70
docker/torch2.1.0-cuda12.1.dockerfile
Normal file
70
docker/torch2.1.0-cuda12.1.dockerfile
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel
|
||||||
|
|
||||||
|
ENV LC_ALL C.UTF-8
|
||||||
|
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# python 3.10
|
||||||
|
ARG K2_VERSION="1.24.4.dev20231021+cuda12.1.torch2.1.0"
|
||||||
|
ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda12.1.torch2.1.0"
|
||||||
|
ARG TORCHAUDIO_VERSION="2.1.0+cu121"
|
||||||
|
|
||||||
|
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
|
||||||
|
LABEL k2_version=${K2_VERSION}
|
||||||
|
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
|
||||||
|
LABEL github_repo="https://github.com/k2-fsa/icefall"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
autoconf \
|
||||||
|
automake \
|
||||||
|
bzip2 \
|
||||||
|
ca-certificates \
|
||||||
|
ffmpeg \
|
||||||
|
g++ \
|
||||||
|
gfortran \
|
||||||
|
git \
|
||||||
|
libtool \
|
||||||
|
make \
|
||||||
|
patch \
|
||||||
|
sox \
|
||||||
|
subversion \
|
||||||
|
unzip \
|
||||||
|
valgrind \
|
||||||
|
wget \
|
||||||
|
zlib1g-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
|
||||||
|
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
|
||||||
|
git+https://github.com/lhotse-speech/lhotse \
|
||||||
|
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
|
||||||
|
kaldi_native_io \
|
||||||
|
kaldialign \
|
||||||
|
kaldifst \
|
||||||
|
kaldilm \
|
||||||
|
sentencepiece>=0.1.96 \
|
||||||
|
tensorboard \
|
||||||
|
typeguard \
|
||||||
|
dill \
|
||||||
|
onnx \
|
||||||
|
onnxruntime \
|
||||||
|
onnxmltools \
|
||||||
|
multi_quantization \
|
||||||
|
typeguard \
|
||||||
|
numpy \
|
||||||
|
pytest \
|
||||||
|
graphviz
|
||||||
|
|
||||||
|
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
|
||||||
|
cd /workspace/icefall && \
|
||||||
|
pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
WORKDIR /workspace/icefall
|
@ -38,7 +38,7 @@ Please fix any issues reported by the check tools.
|
|||||||
.. HINT::
|
.. HINT::
|
||||||
|
|
||||||
Some of the check tools, i.e., ``black`` and ``isort`` will modify
|
Some of the check tools, i.e., ``black`` and ``isort`` will modify
|
||||||
the files to be commited **in-place**. So please run ``git status``
|
the files to be committed **in-place**. So please run ``git status``
|
||||||
after failure to see which file has been modified by the tools
|
after failure to see which file has been modified by the tools
|
||||||
before you make any further changes.
|
before you make any further changes.
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ How to create a recipe
|
|||||||
|
|
||||||
.. HINT::
|
.. HINT::
|
||||||
|
|
||||||
Please read :ref:`follow the code style` to adjust your code sytle.
|
Please read :ref:`follow the code style` to adjust your code style.
|
||||||
|
|
||||||
.. CAUTION::
|
.. CAUTION::
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ during decoding for transducer model:
|
|||||||
\lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
|
\lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
|
||||||
\lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)
|
\lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)
|
||||||
|
|
||||||
In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR,
|
In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR,
|
||||||
the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
|
the only difference lies in the choice of source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
|
||||||
LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
|
LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
|
||||||
As a bi-gram is much faster to evaluate, LODR is usually much faster.
|
As a bi-gram is much faster to evaluate, LODR is usually much faster.
|
||||||
|
@ -30,6 +30,8 @@ which will give you something like below:
|
|||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
"torch2.1.0-cuda12.1"
|
||||||
|
"torch2.1.0-cuda11.8"
|
||||||
"torch2.0.0-cuda11.7"
|
"torch2.0.0-cuda11.7"
|
||||||
"torch1.12.1-cuda11.3"
|
"torch1.12.1-cuda11.3"
|
||||||
"torch1.9.0-cuda10.2"
|
"torch1.9.0-cuda10.2"
|
||||||
|
@ -125,7 +125,7 @@ Python code. We have also set up ``PATH`` so that you can use
|
|||||||
.. caution::
|
.. caution::
|
||||||
|
|
||||||
Please don't use `<https://github.com/tencent/ncnn>`_.
|
Please don't use `<https://github.com/tencent/ncnn>`_.
|
||||||
We have made some modifications to the offical `ncnn`_.
|
We have made some modifications to the official `ncnn`_.
|
||||||
|
|
||||||
We will synchronize `<https://github.com/csukuangfj/ncnn>`_ periodically
|
We will synchronize `<https://github.com/csukuangfj/ncnn>`_ periodically
|
||||||
with the official one.
|
with the official one.
|
||||||
|
@ -67,7 +67,7 @@ To run stage 2 to stage 5, use:
|
|||||||
.. HINT::
|
.. HINT::
|
||||||
|
|
||||||
A 3-gram language model will be downloaded from huggingface, we assume you have
|
A 3-gram language model will be downloaded from huggingface, we assume you have
|
||||||
intalled and initialized ``git-lfs``. If not, you could install ``git-lfs`` by
|
installed and initialized ``git-lfs``. If not, you could install ``git-lfs`` by
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ To run stage 2 to stage 5, use:
|
|||||||
.. HINT::
|
.. HINT::
|
||||||
|
|
||||||
A 3-gram language model will be downloaded from huggingface, we assume you have
|
A 3-gram language model will be downloaded from huggingface, we assume you have
|
||||||
intalled and initialized ``git-lfs``. If not, you could install ``git-lfs`` by
|
installed and initialized ``git-lfs``. If not, you could install ``git-lfs`` by
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
@ -418,7 +418,7 @@ The following shows two examples (for two types of checkpoints):
|
|||||||
|
|
||||||
- ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
|
- ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
|
||||||
`espnet/nets/beam_search_transducer.py <https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py#L247>`_
|
`espnet/nets/beam_search_transducer.py <https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py#L247>`_
|
||||||
is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to
|
is used as a reference. Basically, it keeps topk states for each frame, and expands the kept states with their own contexts to
|
||||||
next frame.
|
next frame.
|
||||||
|
|
||||||
- ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
|
- ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
.. _train_nnlm:
|
.. _train_nnlm:
|
||||||
|
|
||||||
Train an RNN langugage model
|
Train an RNN language model
|
||||||
======================================
|
======================================
|
||||||
|
|
||||||
If you have enough text data, you can train a neural network language model (NNLM) to improve
|
If you have enough text data, you can train a neural network language model (NNLM) to improve
|
||||||
|
@ -32,7 +32,7 @@ In icefall, we implement the streaming conformer the way just like what `WeNet <
|
|||||||
.. HINT::
|
.. HINT::
|
||||||
If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer
|
If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer
|
||||||
to `this pull request <https://github.com/k2-fsa/icefall/pull/454>`_. After adding the code needed by streaming training,
|
to `this pull request <https://github.com/k2-fsa/icefall/pull/454>`_. After adding the code needed by streaming training,
|
||||||
you have to re-train it with the extra arguments metioned in the docs above to get a streaming model.
|
you have to re-train it with the extra arguments mentioned in the docs above to get a streaming model.
|
||||||
|
|
||||||
|
|
||||||
Streaming Emformer
|
Streaming Emformer
|
||||||
|
@ -584,7 +584,7 @@ The following shows two examples (for the two types of checkpoints):
|
|||||||
|
|
||||||
- ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
|
- ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
|
||||||
`espnet/nets/beam_search_transducer.py <https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py#L247>`_
|
`espnet/nets/beam_search_transducer.py <https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search_transducer.py#L247>`_
|
||||||
is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to
|
is used as a reference. Basically, it keeps topk states for each frame, and expands the kept states with their own contexts to
|
||||||
next frame.
|
next frame.
|
||||||
|
|
||||||
- ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
|
- ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
|
||||||
@ -648,7 +648,7 @@ command to extract ``model.state_dict()``.
|
|||||||
.. caution::
|
.. caution::
|
||||||
|
|
||||||
``--streaming-model`` and ``--causal-convolution`` require to be True to export
|
``--streaming-model`` and ``--causal-convolution`` require to be True to export
|
||||||
a streaming mdoel.
|
a streaming model.
|
||||||
|
|
||||||
It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``.
|
It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``.
|
||||||
|
|
||||||
@ -697,7 +697,7 @@ Export model using ``torch.jit.script()``
|
|||||||
.. caution::
|
.. caution::
|
||||||
|
|
||||||
``--streaming-model`` and ``--causal-convolution`` require to be True to export
|
``--streaming-model`` and ``--causal-convolution`` require to be True to export
|
||||||
a streaming mdoel.
|
a streaming model.
|
||||||
|
|
||||||
It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
|
It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
|
||||||
load it by ``torch.jit.load("cpu_jit.pt")``.
|
load it by ``torch.jit.load("cpu_jit.pt")``.
|
||||||
|
8
docs/source/recipes/TTS/index.rst
Normal file
8
docs/source/recipes/TTS/index.rst
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
TTS
|
||||||
|
======
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
|
||||||
|
ljspeech/vits
|
||||||
|
vctk/vits
|
123
docs/source/recipes/TTS/ljspeech/vits.rst
Normal file
123
docs/source/recipes/TTS/ljspeech/vits.rst
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
VITS
|
||||||
|
===============
|
||||||
|
|
||||||
|
This tutorial shows you how to train an VITS model
|
||||||
|
with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
TTS related recipes require packages in ``requirements-tts.txt``.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
|
||||||
|
|
||||||
|
|
||||||
|
Data preparation
|
||||||
|
----------------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ cd egs/ljspeech/TTS
|
||||||
|
$ ./prepare.sh
|
||||||
|
|
||||||
|
To run stage 1 to stage 5, use
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ ./prepare.sh --stage 1 --stop_stage 5
|
||||||
|
|
||||||
|
|
||||||
|
Build Monotonic Alignment Search
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ ./prepare.sh --stage -1 --stop_stage -1
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ cd vits/monotonic_align
|
||||||
|
$ python setup.py build_ext --inplace
|
||||||
|
$ cd ../../
|
||||||
|
|
||||||
|
|
||||||
|
Training
|
||||||
|
--------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
$ ./vits/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 1000 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--exp-dir vits/exp \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
--max-duration 500
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
You can adjust the hyper-parameters to control the size of the VITS model and
|
||||||
|
the training configurations. For more details, please run ``./vits/train.py --help``.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
The training can take a long time (usually a couple of days).
|
||||||
|
|
||||||
|
Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``.
|
||||||
|
|
||||||
|
|
||||||
|
Inference
|
||||||
|
---------
|
||||||
|
|
||||||
|
The inference part uses checkpoints saved by the training part, so you have to run the
|
||||||
|
training part first. It will save the ground-truth and generated wavs to the directory
|
||||||
|
``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ export CUDA_VISIBLE_DEVICES="0"
|
||||||
|
$ ./vits/infer.py \
|
||||||
|
--epoch 1000 \
|
||||||
|
--exp-dir vits/exp \
|
||||||
|
--tokens data/tokens.txt \
|
||||||
|
--max-duration 500
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
For more details, please run ``./vits/infer.py --help``.
|
||||||
|
|
||||||
|
|
||||||
|
Export models
|
||||||
|
-------------
|
||||||
|
|
||||||
|
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
|
||||||
|
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ ./vits/export-onnx.py \
|
||||||
|
--epoch 1000 \
|
||||||
|
--exp-dir vits/exp \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
|
||||||
|
You can test the exported ONNX model with:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ ./vits/test_onnx.py \
|
||||||
|
--model-filename vits/exp/vits-epoch-1000.onnx \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
|
||||||
|
|
||||||
|
Download pretrained models
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
If you don't want to train from scratch, you can download the pretrained models
|
||||||
|
by visiting the following link:
|
||||||
|
|
||||||
|
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
|
125
docs/source/recipes/TTS/vctk/vits.rst
Normal file
125
docs/source/recipes/TTS/vctk/vits.rst
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
VITS
|
||||||
|
===============
|
||||||
|
|
||||||
|
This tutorial shows you how to train an VITS model
|
||||||
|
with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
TTS related recipes require packages in ``requirements-tts.txt``.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
|
||||||
|
|
||||||
|
|
||||||
|
Data preparation
|
||||||
|
----------------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ cd egs/vctk/TTS
|
||||||
|
$ ./prepare.sh
|
||||||
|
|
||||||
|
To run stage 1 to stage 6, use
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ ./prepare.sh --stage 1 --stop_stage 6
|
||||||
|
|
||||||
|
|
||||||
|
Build Monotonic Alignment Search
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
To build the monotonic alignment search, use the following commands:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ ./prepare.sh --stage -1 --stop_stage -1
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ cd vits/monotonic_align
|
||||||
|
$ python setup.py build_ext --inplace
|
||||||
|
$ cd ../../
|
||||||
|
|
||||||
|
|
||||||
|
Training
|
||||||
|
--------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
$ ./vits/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 1000 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--exp-dir vits/exp \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
--max-duration 350
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
You can adjust the hyper-parameters to control the size of the VITS model and
|
||||||
|
the training configurations. For more details, please run ``./vits/train.py --help``.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
The training can take a long time (usually a couple of days).
|
||||||
|
|
||||||
|
Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``.
|
||||||
|
|
||||||
|
|
||||||
|
Inference
|
||||||
|
---------
|
||||||
|
|
||||||
|
The inference part uses checkpoints saved by the training part, so you have to run the
|
||||||
|
training part first. It will save the ground-truth and generated wavs to the directory
|
||||||
|
``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ export CUDA_VISIBLE_DEVICES="0"
|
||||||
|
$ ./vits/infer.py \
|
||||||
|
--epoch 1000 \
|
||||||
|
--exp-dir vits/exp \
|
||||||
|
--tokens data/tokens.txt \
|
||||||
|
--max-duration 500
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
For more details, please run ``./vits/infer.py --help``.
|
||||||
|
|
||||||
|
|
||||||
|
Export models
|
||||||
|
-------------
|
||||||
|
|
||||||
|
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
|
||||||
|
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ ./vits/export-onnx.py \
|
||||||
|
--epoch 1000 \
|
||||||
|
--exp-dir vits/exp \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
|
||||||
|
You can test the exported ONNX model with:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ ./vits/test_onnx.py \
|
||||||
|
--model-filename vits/exp/vits-epoch-1000.onnx \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
|
||||||
|
|
||||||
|
Download pretrained models
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
If you don't want to train from scratch, you can download the pretrained models
|
||||||
|
by visiting the following link:
|
||||||
|
|
||||||
|
- `<https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05>`_
|
@ -2,7 +2,7 @@ Recipes
|
|||||||
=======
|
=======
|
||||||
|
|
||||||
This page contains various recipes in ``icefall``.
|
This page contains various recipes in ``icefall``.
|
||||||
Currently, only speech recognition recipes are provided.
|
Currently, we provide recipes for speech recognition, language model, and speech synthesis.
|
||||||
|
|
||||||
We may add recipes for other tasks as well in the future.
|
We may add recipes for other tasks as well in the future.
|
||||||
|
|
||||||
@ -16,3 +16,4 @@ We may add recipes for other tasks as well in the future.
|
|||||||
Non-streaming-ASR/index
|
Non-streaming-ASR/index
|
||||||
Streaming-ASR/index
|
Streaming-ASR/index
|
||||||
RNN-LM/index
|
RNN-LM/index
|
||||||
|
TTS/index
|
||||||
|
@ -7,6 +7,8 @@ set -eou pipefail
|
|||||||
|
|
||||||
stage=-1
|
stage=-1
|
||||||
stop_stage=100
|
stop_stage=100
|
||||||
|
perturb_speed=true
|
||||||
|
|
||||||
|
|
||||||
# We assume dl_dir (download dir) contains the following
|
# We assume dl_dir (download dir) contains the following
|
||||||
# directories and files. If not, they will be downloaded
|
# directories and files. If not, they will be downloaded
|
||||||
@ -77,7 +79,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
log "Stage 4: Compute fbank for aidatatang_200zh"
|
log "Stage 4: Compute fbank for aidatatang_200zh"
|
||||||
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
|
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
|
./local/compute_fbank_aidatatang_200zh.py --perturb-speed ${perturb_speed}
|
||||||
touch data/fbank/.aidatatang_200zh.done
|
touch data/fbank/.aidatatang_200zh.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -102,7 +102,7 @@ class Aidatatang_200zhAsrDataModule:
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--bucketing-sampler",
|
"--bucketing-sampler",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=True,
|
default=False,
|
||||||
help="When enabled, the batches will come from buckets of "
|
help="When enabled, the batches will come from buckets of "
|
||||||
"similar duration (saves padding frames).",
|
"similar duration (saves padding frames).",
|
||||||
)
|
)
|
||||||
@ -211,7 +211,7 @@ class Aidatatang_200zhAsrDataModule:
|
|||||||
if self.args.enable_musan:
|
if self.args.enable_musan:
|
||||||
logging.info("Enable MUSAN")
|
logging.info("Enable MUSAN")
|
||||||
transforms.append(
|
transforms.append(
|
||||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Disable MUSAN")
|
logging.info("Disable MUSAN")
|
||||||
@ -289,6 +289,7 @@ class Aidatatang_200zhAsrDataModule:
|
|||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
|
buffer_size=50000,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SimpleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
|
|
||||||
# Introduction
|
# Introduction
|
||||||
|
|
||||||
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/aishell/index.html>
|
Please refer to <https://k2-fsa.github.io/icefall/recipes/Non-streaming-ASR/aishell/index.html> for how to run models in this recipe.
|
||||||
for how to run models in this recipe.
|
|
||||||
|
|
||||||
|
Aishell is an open-source Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co., Ltd.
|
||||||
|
400 people from different accent areas in China are invited to participate in the recording, which is conducted in a quiet indoor environment using high fidelity microphone and downsampled to 16kHz. The manual transcription accuracy is above 95%, through professional speech annotation and strict quality inspection. The data is free for academic use. We hope to provide moderate amount of data for new researchers in the field of speech recognition.
|
||||||
|
|
||||||
|
(From [Open Speech and Language Resources](https://www.openslr.org/33/))
|
||||||
|
|
||||||
# Transducers
|
# Transducers
|
||||||
|
|
||||||
|
@ -1,6 +1,212 @@
|
|||||||
## Results
|
## Results
|
||||||
|
|
||||||
### Aishell training result(Stateless Transducer)
|
### Aishell training result (Stateless Transducer)
|
||||||
|
|
||||||
|
#### Zipformer (Non-streaming)
|
||||||
|
|
||||||
|
[./zipformer](./zipformer)
|
||||||
|
|
||||||
|
It's reworked Zipformer with Pruned RNNT loss.
|
||||||
|
**Caution**: It uses `--context-size=1`.
|
||||||
|
|
||||||
|
##### normal-scaled model, number of model parameters: 73412551, i.e., 73.41 M
|
||||||
|
|
||||||
|
| | test | dev | comment |
|
||||||
|
|------------------------|------|------|-----------------------------------------|
|
||||||
|
| greedy search | 4.67 | 4.37 | --epoch 55 --avg 17 |
|
||||||
|
| modified beam search | 4.40 | 4.13 | --epoch 55 --avg 17 |
|
||||||
|
| fast beam search | 4.60 | 4.31 | --epoch 55 --avg 17 |
|
||||||
|
|
||||||
|
Command for training is:
|
||||||
|
```bash
|
||||||
|
./prepare.sh
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1"
|
||||||
|
|
||||||
|
./zipformer/train.py \
|
||||||
|
--world-size 2 \
|
||||||
|
--num-epochs 60 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--context-size 1 \
|
||||||
|
--enable-musan 0 \
|
||||||
|
--exp-dir zipformer/exp \
|
||||||
|
--max-duration 1000 \
|
||||||
|
--enable-musan 0 \
|
||||||
|
--base-lr 0.045 \
|
||||||
|
--lr-batches 7500 \
|
||||||
|
--lr-epochs 18 \
|
||||||
|
--spec-aug-time-warp-factor 20
|
||||||
|
```
|
||||||
|
|
||||||
|
Command for decoding is:
|
||||||
|
```bash
|
||||||
|
for m in greedy_search modified_beam_search fast_beam_search ; do
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 55 \
|
||||||
|
--avg 17 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--context-size 1 \
|
||||||
|
--decoding-method $m
|
||||||
|
done
|
||||||
|
```
|
||||||
|
Pretrained models, training logs, decoding logs, tensorboard and decoding results
|
||||||
|
are available at
|
||||||
|
<https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24>
|
||||||
|
|
||||||
|
|
||||||
|
##### small-scaled model, number of model parameters: 30167139, i.e., 30.17 M
|
||||||
|
|
||||||
|
| | test | dev | comment |
|
||||||
|
|------------------------|------|------|-----------------------------------------|
|
||||||
|
| greedy search | 4.97 | 4.67 | --epoch 55 --avg 21 |
|
||||||
|
| modified beam search | 4.67 | 4.40 | --epoch 55 --avg 21 |
|
||||||
|
| fast beam search | 4.85 | 4.61 | --epoch 55 --avg 21 |
|
||||||
|
|
||||||
|
Command for training is:
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1"
|
||||||
|
|
||||||
|
./zipformer/train.py \
|
||||||
|
--world-size 2 \
|
||||||
|
--num-epochs 60 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--context-size 1 \
|
||||||
|
--exp-dir zipformer/exp-small \
|
||||||
|
--enable-musan 0 \
|
||||||
|
--base-lr 0.045 \
|
||||||
|
--lr-batches 7500 \
|
||||||
|
--lr-epochs 18 \
|
||||||
|
--spec-aug-time-warp-factor 20 \
|
||||||
|
--num-encoder-layers 2,2,2,2,2,2 \
|
||||||
|
--feedforward-dim 512,768,768,768,768,768 \
|
||||||
|
--encoder-dim 192,256,256,256,256,256 \
|
||||||
|
--encoder-unmasked-dim 192,192,192,192,192,192 \
|
||||||
|
--max-duration 1200
|
||||||
|
```
|
||||||
|
|
||||||
|
Command for decoding is:
|
||||||
|
```bash
|
||||||
|
for m in greedy_search modified_beam_search fast_beam_search ; do
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 55 \
|
||||||
|
--avg 21 \
|
||||||
|
--exp-dir ./zipformer/exp-small \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--context-size 1 \
|
||||||
|
--decoding-method $m \
|
||||||
|
--num-encoder-layers 2,2,2,2,2,2 \
|
||||||
|
--feedforward-dim 512,768,768,768,768,768 \
|
||||||
|
--encoder-dim 192,256,256,256,256,256 \
|
||||||
|
--encoder-unmasked-dim 192,192,192,192,192,192
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
Pretrained models, training logs, decoding logs, tensorboard and decoding results
|
||||||
|
are available at
|
||||||
|
<https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/>
|
||||||
|
|
||||||
|
##### large-scaled model, number of model parameters: 157285130, i.e., 157.29 M
|
||||||
|
|
||||||
|
| | test | dev | comment |
|
||||||
|
|------------------------|------|------|-----------------------------------------|
|
||||||
|
| greedy search | 4.49 | 4.22 | --epoch 56 --avg 23 |
|
||||||
|
| modified beam search | 4.28 | 4.03 | --epoch 56 --avg 23 |
|
||||||
|
| fast beam search | 4.44 | 4.18 | --epoch 56 --avg 23 |
|
||||||
|
|
||||||
|
Command for training is:
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1"
|
||||||
|
|
||||||
|
./zipformer/train.py \
|
||||||
|
--world-size 2 \
|
||||||
|
--num-epochs 60 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--context-size 1 \
|
||||||
|
--exp-dir ./zipformer/exp-large \
|
||||||
|
--enable-musan 0 \
|
||||||
|
--lr-batches 7500 \
|
||||||
|
--lr-epochs 18 \
|
||||||
|
--spec-aug-time-warp-factor 20 \
|
||||||
|
--num-encoder-layers 2,2,4,5,4,2 \
|
||||||
|
--feedforward-dim 512,768,1536,2048,1536,768 \
|
||||||
|
--encoder-dim 192,256,512,768,512,256 \
|
||||||
|
--encoder-unmasked-dim 192,192,256,320,256,192 \
|
||||||
|
--max-duration 800
|
||||||
|
```
|
||||||
|
|
||||||
|
Command for decoding is:
|
||||||
|
```bash
|
||||||
|
for m in greedy_search modified_beam_search fast_beam_search ; do
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 56 \
|
||||||
|
--avg 23 \
|
||||||
|
--exp-dir ./zipformer/exp-large \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--context-size 1 \
|
||||||
|
--decoding-method $m \
|
||||||
|
--num-encoder-layers 2,2,4,5,4,2 \
|
||||||
|
--feedforward-dim 512,768,1536,2048,1536,768 \
|
||||||
|
--encoder-dim 192,256,512,768,512,256 \
|
||||||
|
--encoder-unmasked-dim 192,192,256,320,256,192
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
Pretrained models, training logs, decoding logs, tensorboard and decoding results
|
||||||
|
are available at
|
||||||
|
<https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/>
|
||||||
|
|
||||||
|
#### Pruned transducer stateless 7 streaming
|
||||||
|
[./pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
|
||||||
|
|
||||||
|
It's Streaming version of Zipformer1 with Pruned RNNT loss.
|
||||||
|
|
||||||
|
| | test | dev | comment |
|
||||||
|
|------------------------|------|------|---------------------------------------|
|
||||||
|
| greedy search | 6.95 | 6.29 | --epoch 44 --avg 15 --max-duration 600 |
|
||||||
|
| modified beam search | 6.51 | 5.90 | --epoch 44 --avg 15 --max-duration 600 |
|
||||||
|
| fast beam search | 6.73 | 6.09 | --epoch 44 --avg 15 --max-duration 600 |
|
||||||
|
|
||||||
|
Training command is:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./prepare.sh
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/train.py \
|
||||||
|
--world-size 2 \
|
||||||
|
--num-epochs 50 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--context-size 1 \
|
||||||
|
--max-duration 800 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--enable-musan 0 \
|
||||||
|
--spec-aug-time-warp-factor 20
|
||||||
|
```
|
||||||
|
|
||||||
|
**Caution**: It uses `--context-size=1`.
|
||||||
|
|
||||||
|
The decoding command is:
|
||||||
|
```bash
|
||||||
|
for m in greedy_search modified_beam_search fast_beam_search ; do
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 44 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--context-size 1 \
|
||||||
|
--decoding-method $m
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
Pretrained models, training logs, decoding logs, tensorboard and decoding results
|
||||||
|
are available at
|
||||||
|
<https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-streaming-2023-10-16/>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#### Pruned transducer stateless 7
|
#### Pruned transducer stateless 7
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ set -eou pipefail
|
|||||||
nj=15
|
nj=15
|
||||||
stage=-1
|
stage=-1
|
||||||
stop_stage=11
|
stop_stage=11
|
||||||
|
perturb_speed=true
|
||||||
|
|
||||||
# We assume dl_dir (download dir) contains the following
|
# We assume dl_dir (download dir) contains the following
|
||||||
# directories and files. If not, they will be downloaded
|
# directories and files. If not, they will be downloaded
|
||||||
@ -114,7 +115,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
log "Stage 3: Compute fbank for aishell"
|
log "Stage 3: Compute fbank for aishell"
|
||||||
if [ ! -f data/fbank/.aishell.done ]; then
|
if [ ! -f data/fbank/.aishell.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aishell.py --perturb-speed True
|
./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed}
|
||||||
touch data/fbank/.aishell.done
|
touch data/fbank/.aishell.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
@ -204,10 +205,6 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|||||||
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
||||||
./local/prepare_char.py --lang-dir $lang_char_dir
|
./local/prepare_char.py --lang-dir $lang_char_dir
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ! -f $lang_char_dir/HLG.fst ]; then
|
|
||||||
./local/prepare_lang_fst.py --lang-dir $lang_phone_dir --ngram-G ./data/lm/G_3_gram.fst.txt
|
|
||||||
fi
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||||
@ -246,7 +243,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
|||||||
-lm data/lm/3-gram.unpruned.arpa
|
-lm data/lm/3-gram.unpruned.arpa
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# We assume you have install kaldilm, if not, please install
|
# We assume you have installed kaldilm, if not, please install
|
||||||
# it using: pip install kaldilm
|
# it using: pip install kaldilm
|
||||||
if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
|
if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
|
||||||
# It is used in building HLG
|
# It is used in building HLG
|
||||||
@ -262,6 +259,12 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
|||||||
--max-order=3 \
|
--max-order=3 \
|
||||||
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
|
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $lang_char_dir/HLG.fst ]; then
|
||||||
|
./local/prepare_lang_fst.py \
|
||||||
|
--lang-dir $lang_char_dir \
|
||||||
|
--ngram-G ./data/lm/G_3_gram_char.fst.txt
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||||
|
@ -641,7 +641,7 @@ def main():
|
|||||||
contexts_text.append(line.strip())
|
contexts_text.append(line.strip())
|
||||||
contexts = graph_compiler.texts_to_ids(contexts_text)
|
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||||
context_graph = ContextGraph(params.context_score)
|
context_graph = ContextGraph(params.context_score)
|
||||||
context_graph.build(contexts)
|
context_graph.build([(c, 0.0) for c in contexts])
|
||||||
else:
|
else:
|
||||||
context_graph = None
|
context_graph = None
|
||||||
else:
|
else:
|
||||||
|
@ -1234,6 +1234,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
raise RuntimeError("Please don't use this file directly!")
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
AsrDataModule.add_arguments(parser)
|
AsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
@ -56,7 +56,7 @@ import torch.nn as nn
|
|||||||
from decoder2 import Decoder
|
from decoder2 import Decoder
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||||
from zipformer import Zipformer
|
from zipformer import Zipformer
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
|
@ -703,7 +703,7 @@ def compute_loss(
|
|||||||
if batch_idx_train >= warm_step
|
if batch_idx_train >= warm_step
|
||||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||||
)
|
)
|
||||||
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
|
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/README.md
|
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/asr_datamodule.py
|
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/beam_search.py
|
735
egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
735
egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
@ -0,0 +1,735 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
(1) greedy search
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
|
||||||
|
(2) beam search (not recommended)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(3) modified beam search
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) fast beam search (one best)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
|
|
||||||
|
(5) fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (nbest oracle WER)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(7) fast beam search (with LG)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method fast_beam_search_nbest_LG \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import AishellAsrDataModule
|
||||||
|
from beam_search import (
|
||||||
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall import ContextGraph
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless3/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="""Possible values are:
|
||||||
|
- greedy_search
|
||||||
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam-size",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Maximum number of symbols per frame.
|
||||||
|
Used only when --decoding_method is greedy_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-score",
|
||||||
|
type=float,
|
||||||
|
default=2,
|
||||||
|
help="""
|
||||||
|
The bonus score of each token for the context biasing words/phrases.
|
||||||
|
Used only when --decoding_method is modified_beam_search.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-file",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The path of the context biasing lists, one word/phrase each line
|
||||||
|
Used only when --decoding_method is modified_beam_search.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
token_table: k2.SymbolTable,
|
||||||
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
|
) -> Dict[str, List[List[str]]]:
|
||||||
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
|
following format:
|
||||||
|
|
||||||
|
- key: It indicates the setting used for decoding. For example,
|
||||||
|
if greedy_search is used, it would be "greedy_search"
|
||||||
|
If beam search with a beam size of 7 is used, it would be
|
||||||
|
"beam_7"
|
||||||
|
- value: It contains the decoding result. `len(value)` equals to
|
||||||
|
batch size. `value[i]` is the decoding result for the i-th
|
||||||
|
utterance in the given batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
It maps token ID to a string.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
|
Returns:
|
||||||
|
Return the decoding result. See above description for the format of
|
||||||
|
the returned dict.
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
feature_lens += 30
|
||||||
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, 30),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
context_graph=context_graph,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hyp_tokens = []
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
for i in range(batch_size):
|
||||||
|
# fmt: off
|
||||||
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
|
# fmt: on
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "beam_search":
|
||||||
|
hyp = beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
|
hyp_tokens.append(hyp)
|
||||||
|
|
||||||
|
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
return {"greedy_search": hyps}
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
key = f"beam_size_{params.beam_size}"
|
||||||
|
if params.has_contexts:
|
||||||
|
key += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
key += "-no-context-words"
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
token_table: Optional[k2.SymbolTable] = None,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
It maps a token ID to a string.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
log_interval = 50
|
||||||
|
else:
|
||||||
|
log_interval = 20
|
||||||
|
|
||||||
|
results = defaultdict(list)
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
|
hyps_dict = decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
token_table=token_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
|
batch=batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, hyps in hyps_dict.items():
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
|
ref_words = ref_text.split()
|
||||||
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
|
):
|
||||||
|
test_set_wers = dict()
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
AishellAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
args.lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
assert params.decoding_method in (
|
||||||
|
"greedy_search",
|
||||||
|
"beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
|
"modified_beam_search",
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
params.has_contexts = True
|
||||||
|
else:
|
||||||
|
params.has_contexts = False
|
||||||
|
|
||||||
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
if params.has_contexts:
|
||||||
|
params.suffix += f"-context-score-{params.context_score}"
|
||||||
|
else:
|
||||||
|
params.suffix += "-no-contexts-words"
|
||||||
|
else:
|
||||||
|
params.suffix += f"-context-{params.context_size}"
|
||||||
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
params.blank_id = 0
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
|
lexicon=lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||||
|
model.encoder.decode_chunk_size,
|
||||||
|
params.decode_chunk_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
|
||||||
|
if params.decoding_method == "modified_beam_search":
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
contexts_text = []
|
||||||
|
for line in open(params.context_file).readlines():
|
||||||
|
contexts_text.append(line.strip())
|
||||||
|
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||||
|
context_graph = ContextGraph(params.context_score)
|
||||||
|
context_graph.build([(c, 0.0) for c in contexts])
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
aishell = AishellAsrDataModule(args)
|
||||||
|
|
||||||
|
test_cuts = aishell.test_cuts()
|
||||||
|
dev_cuts = aishell.valid_cuts()
|
||||||
|
|
||||||
|
test_dl = aishell.test_dataloaders(test_cuts)
|
||||||
|
dev_dl = aishell.test_dataloaders(dev_cuts)
|
||||||
|
|
||||||
|
test_sets = ["test", "dev"]
|
||||||
|
test_dls = [test_dl, dev_dl]
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
|
start = time.time()
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
token_table=lexicon.token_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
|
)
|
||||||
|
logging.info(f"Elasped time for {test_set}: {time.time() - start}")
|
||||||
|
|
||||||
|
save_results(
|
||||||
|
params=params,
|
||||||
|
test_set_name=test_set,
|
||||||
|
results_dict=results_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
|
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless7/decoder.py
|
1254
egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
Executable file
1254
egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/encoder_interface.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless7/joiner.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless7/model.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
|
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless7/optim.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
|
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless7/scaling.py
|
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless7/scaling_converter.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py
|
627
egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
627
egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
@ -0,0 +1,627 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
./pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--decoding-method greedy_search \
|
||||||
|
--num-decode-streams 2000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import AishellAsrDataModule
|
||||||
|
from decode_stream import DecodeStream
|
||||||
|
from kaldifeat import Fbank, FbankOptions
|
||||||
|
from lhotse import CutSet
|
||||||
|
from streaming_beam_search import (
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
from zipformer import stack_states, unstack_states
|
||||||
|
|
||||||
|
from icefall import ContextGraph
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless7_streaming/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="""Supported decoding methods are:
|
||||||
|
greedy_search
|
||||||
|
modified_beam_search
|
||||||
|
fast_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_active_paths",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An interger indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-decode-streams",
|
||||||
|
type=int,
|
||||||
|
default=2000,
|
||||||
|
help="The number of streams that can be decoded parallel.",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_chunk(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
decode_streams: List[DecodeStream],
|
||||||
|
) -> List[int]:
|
||||||
|
"""Decode one chunk frames of features for each decode_streams and
|
||||||
|
return the indexes of finished streams in a List.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
decode_streams:
|
||||||
|
A List of DecodeStream, each belonging to a utterance.
|
||||||
|
Returns:
|
||||||
|
Return a List containing which DecodeStreams are finished.
|
||||||
|
"""
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
features = []
|
||||||
|
feature_lens = []
|
||||||
|
states = []
|
||||||
|
processed_lens = []
|
||||||
|
|
||||||
|
for stream in decode_streams:
|
||||||
|
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
|
||||||
|
features.append(feat)
|
||||||
|
feature_lens.append(feat_len)
|
||||||
|
states.append(stream.states)
|
||||||
|
processed_lens.append(stream.done_frames)
|
||||||
|
|
||||||
|
feature_lens = torch.tensor(feature_lens, device=device)
|
||||||
|
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||||
|
|
||||||
|
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
|
||||||
|
# factor in encoders is 8.
|
||||||
|
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
|
||||||
|
tail_length = 23
|
||||||
|
if features.size(1) < tail_length:
|
||||||
|
pad_length = tail_length - features.size(1)
|
||||||
|
feature_lens += pad_length
|
||||||
|
features = torch.nn.functional.pad(
|
||||||
|
features,
|
||||||
|
(0, 0, 0, pad_length),
|
||||||
|
mode="constant",
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
states = stack_states(states)
|
||||||
|
processed_lens = torch.tensor(processed_lens, device=device)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
|
||||||
|
x=features,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
states=states,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
processed_lens = processed_lens + encoder_out_lens
|
||||||
|
fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
processed_lens=processed_lens,
|
||||||
|
streams=decode_streams,
|
||||||
|
beam=params.beam,
|
||||||
|
max_states=params.max_states,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
streams=decode_streams,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
num_active_paths=params.num_active_paths,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||||
|
|
||||||
|
states = unstack_states(new_states)
|
||||||
|
|
||||||
|
finished_streams = []
|
||||||
|
for i in range(len(decode_streams)):
|
||||||
|
decode_streams[i].states = states[i]
|
||||||
|
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||||
|
if decode_streams[i].done:
|
||||||
|
finished_streams.append(i)
|
||||||
|
|
||||||
|
return finished_streams
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
cuts: CutSet,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
token_table: Optional[k2.SymbolTable] = None,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cuts:
|
||||||
|
Lhotse Cutset containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
opts = FbankOptions()
|
||||||
|
opts.device = device
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = 16000
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
|
||||||
|
log_interval = 50
|
||||||
|
|
||||||
|
decode_results = []
|
||||||
|
# Contain decode streams currently running.
|
||||||
|
decode_streams = []
|
||||||
|
for num, cut in enumerate(cuts):
|
||||||
|
# each utterance has a DecodeStream.
|
||||||
|
initial_states = model.encoder.get_init_state(device=device)
|
||||||
|
decode_stream = DecodeStream(
|
||||||
|
params=params,
|
||||||
|
cut_id=cut.id,
|
||||||
|
initial_states=initial_states,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio: np.ndarray = cut.load_audio()
|
||||||
|
# audio.shape: (1, num_samples)
|
||||||
|
assert len(audio.shape) == 2
|
||||||
|
assert audio.shape[0] == 1, "Should be single channel"
|
||||||
|
assert audio.dtype == np.float32, audio.dtype
|
||||||
|
|
||||||
|
# The trained model is using normalized samples
|
||||||
|
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||||
|
|
||||||
|
samples = torch.from_numpy(audio).squeeze(0)
|
||||||
|
|
||||||
|
fbank = Fbank(opts)
|
||||||
|
feature = fbank(samples.to(device))
|
||||||
|
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
|
||||||
|
decode_stream.ground_truth = cut.supervisions[0].text
|
||||||
|
|
||||||
|
decode_streams.append(decode_stream)
|
||||||
|
|
||||||
|
while len(decode_streams) >= params.num_decode_streams:
|
||||||
|
finished_streams = decode_one_chunk(
|
||||||
|
params=params, model=model, decode_streams=decode_streams
|
||||||
|
)
|
||||||
|
for i in sorted(finished_streams, reverse=True):
|
||||||
|
decode_results.append(
|
||||||
|
(
|
||||||
|
decode_streams[i].id,
|
||||||
|
decode_streams[i].ground_truth.split(),
|
||||||
|
[
|
||||||
|
token_table[result]
|
||||||
|
for result in decode_streams[i].decoding_result()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
del decode_streams[i]
|
||||||
|
|
||||||
|
if num % log_interval == 0:
|
||||||
|
logging.info(f"Cuts processed until now is {num}.")
|
||||||
|
|
||||||
|
# decode final chunks of last sequences
|
||||||
|
while len(decode_streams):
|
||||||
|
finished_streams = decode_one_chunk(
|
||||||
|
params=params, model=model, decode_streams=decode_streams
|
||||||
|
)
|
||||||
|
for i in sorted(finished_streams, reverse=True):
|
||||||
|
decode_results.append(
|
||||||
|
(
|
||||||
|
decode_streams[i].id,
|
||||||
|
decode_streams[i].ground_truth.split(),
|
||||||
|
[
|
||||||
|
token_table[result]
|
||||||
|
for result in decode_streams[i].decoding_result()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
del decode_streams[i]
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
key = "greedy_search"
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
key = (
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
key = f"num_active_paths_{params.num_active_paths}"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||||
|
return {key: decode_results}
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||||
|
):
|
||||||
|
test_set_wers = dict()
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
AishellAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
# for streaming
|
||||||
|
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||||
|
|
||||||
|
# for fast_beam_search
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
params.blank_id = 0
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
|
lexicon=lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if start >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
model.device = device
|
||||||
|
|
||||||
|
decoding_graph = None
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
|
||||||
|
if params.decoding_method == "modified_beam_search":
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
contexts_text = []
|
||||||
|
for line in open(params.context_file).readlines():
|
||||||
|
contexts_text.append(line.strip())
|
||||||
|
contexts = graph_compiler.texts_to_ids(contexts_text)
|
||||||
|
context_graph = ContextGraph(params.context_score)
|
||||||
|
context_graph.build(contexts)
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
aishell = AishellAsrDataModule(args)
|
||||||
|
|
||||||
|
test_cuts = aishell.test_cuts()
|
||||||
|
valid_cuts = aishell.valid_cuts()
|
||||||
|
|
||||||
|
test_sets = ["test", "valid"]
|
||||||
|
cuts = [test_cuts, valid_cuts]
|
||||||
|
|
||||||
|
for test_set, test_cut in zip(test_sets, cuts):
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
cuts=test_cut,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
token_table=lexicon.token_table,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_results(
|
||||||
|
params=params,
|
||||||
|
test_set_name=test_set,
|
||||||
|
results_dict=results_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py
|
1251
egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
1251
egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
|
@ -70,6 +70,10 @@ class Decoder(nn.Module):
|
|||||||
groups=embedding_dim,
|
groups=embedding_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
||||||
|
# when inference with torch.jit.script and context_size == 1
|
||||||
|
self.conv = nn.Identity()
|
||||||
|
|
||||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
|
@ -160,7 +160,7 @@ class AsrDataModule:
|
|||||||
if cuts_musan is not None:
|
if cuts_musan is not None:
|
||||||
logging.info("Enable MUSAN")
|
logging.info("Enable MUSAN")
|
||||||
transforms.append(
|
transforms.append(
|
||||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Disable MUSAN")
|
logging.info("Disable MUSAN")
|
||||||
|
0
egs/aishell/ASR/zipformer/__init__.py
Normal file
0
egs/aishell/ASR/zipformer/__init__.py
Normal file
1
egs/aishell/ASR/zipformer/asr_datamodule.py
Symbolic link
1
egs/aishell/ASR/zipformer/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/asr_datamodule.py
|
1
egs/aishell/ASR/zipformer/beam_search.py
Symbolic link
1
egs/aishell/ASR/zipformer/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/beam_search.py
|
814
egs/aishell/ASR/zipformer/decode.py
Executable file
814
egs/aishell/ASR/zipformer/decode.py
Executable file
@ -0,0 +1,814 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao
|
||||||
|
# Mingshuang Luo)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
(1) greedy search
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 35 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
|
||||||
|
(2) modified beam search
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 35 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(3) fast beam search (trivial_graph)
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 35 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
|
|
||||||
|
(4) fast beam search (LG)
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_LG \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
|
|
||||||
|
(5) fast beam search (nbest oracle WER)
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 35 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import AishellAsrDataModule
|
||||||
|
from beam_search import (
|
||||||
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG,
|
||||||
|
fast_beam_search_nbest_oracle,
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from lhotse.cut import Cut
|
||||||
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
make_pad_mask,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipformer/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=Path,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="The lang dir containing word table and LG graph",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="""Possible values are:
|
||||||
|
- greedy_search
|
||||||
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
|
- fast_beam_search_LG
|
||||||
|
- fast_beam_search_nbest_oracle
|
||||||
|
If you use fast_beam_search_LG, you have to specify
|
||||||
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam-size",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=20.0,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search,
|
||||||
|
fast_beam_search, fast_beam_search_LG,
|
||||||
|
and fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding_method is fast_beam_search_LG.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ilme-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.2,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding_method is fast_beam_search_LG.
|
||||||
|
It specifies the scale for the internal language model estimation.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search, fast_beam_search_LG,
|
||||||
|
and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search, fast_beam_search_LG,
|
||||||
|
and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Maximum number of symbols per frame.
|
||||||
|
Used only when --decoding_method is greedy_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Number of paths for nbest decoding.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding method is and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
lexicon: Lexicon,
|
||||||
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[List[str]]]:
|
||||||
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
|
following format:
|
||||||
|
|
||||||
|
- key: It indicates the setting used for decoding. For example,
|
||||||
|
if greedy_search is used, it would be "greedy_search"
|
||||||
|
If beam search with a beam size of 7 is used, it would be
|
||||||
|
"beam_7"
|
||||||
|
- value: It contains the decoding result. `len(value)` equals to
|
||||||
|
batch size. `value[i]` is the decoding result for the i-th
|
||||||
|
utterance in the given batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
|
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
|
Returns:
|
||||||
|
Return the decoding result. See above description for the format of
|
||||||
|
the returned dict.
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
if params.causal:
|
||||||
|
# this seems to cause insertions at the end of the utterance if used with zipformer.
|
||||||
|
pad_len = 30
|
||||||
|
feature_lens += pad_len
|
||||||
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, pad_len),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
x, x_lens = model.encoder_embed(feature, feature_lens)
|
||||||
|
|
||||||
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
|
||||||
|
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
|
hyps = []
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
elif params.decoding_method == "fast_beam_search_LG":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
|
ilme_scale=params.ilme_scale,
|
||||||
|
)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
sentence = "".join([lexicon.word_table[i] for i in hyp])
|
||||||
|
hyps.append(list(sentence))
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
else:
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
# fmt: off
|
||||||
|
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
||||||
|
# fmt: on
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "beam_search":
|
||||||
|
hyp = beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp])
|
||||||
|
|
||||||
|
key = f"blank_penalty_{params.blank_penalty}"
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
return {"greedy_search_" + key: hyps}
|
||||||
|
elif "fast_beam_search" in params.decoding_method:
|
||||||
|
key += f"_beam_{params.beam}_"
|
||||||
|
key += f"max_contexts_{params.max_contexts}_"
|
||||||
|
key += f"max_states_{params.max_states}"
|
||||||
|
if "nbest" in params.decoding_method:
|
||||||
|
key += f"_num_paths_{params.num_paths}_"
|
||||||
|
key += f"nbest_scale_{params.nbest_scale}"
|
||||||
|
if "LG" in params.decoding_method:
|
||||||
|
key += f"_ilme_scale_{params.ilme_scale}"
|
||||||
|
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||||
|
|
||||||
|
return {key: hyps}
|
||||||
|
else:
|
||||||
|
return {f"beam_size_{params.beam_size}_" + key: hyps}
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
lexicon: Lexicon,
|
||||||
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||||
|
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
log_interval = 50
|
||||||
|
else:
|
||||||
|
log_interval = 20
|
||||||
|
|
||||||
|
results = defaultdict(list)
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
texts = [list("".join(text.split())) for text in texts]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
|
hyps_dict = decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
lexicon=lexicon,
|
||||||
|
graph_compiler=graph_compiler,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
batch=batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, hyps in hyps_dict.items():
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
|
this_batch.append((cut_id, ref_text, hyp_words))
|
||||||
|
|
||||||
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||||
|
):
|
||||||
|
test_set_wers = dict()
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = (
|
||||||
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = (
|
||||||
|
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = (
|
||||||
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
AishellAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
assert params.decoding_method in (
|
||||||
|
"greedy_search",
|
||||||
|
"beam_search",
|
||||||
|
"modified_beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_LG",
|
||||||
|
"fast_beam_search_nbest_oracle",
|
||||||
|
)
|
||||||
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
if params.causal:
|
||||||
|
assert (
|
||||||
|
"," not in params.chunk_size
|
||||||
|
), "chunk_size should be one value in decoding."
|
||||||
|
assert (
|
||||||
|
"," not in params.left_context_frames
|
||||||
|
), "left_context_frames should be one value in decoding."
|
||||||
|
params.suffix += f"-chunk-{params.chunk_size}"
|
||||||
|
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||||
|
|
||||||
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
if "nbest" in params.decoding_method:
|
||||||
|
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||||
|
params.suffix += f"-num-paths-{params.num_paths}"
|
||||||
|
if "LG" in params.decoding_method:
|
||||||
|
params.suffix += f"_ilme_scale_{params.ilme_scale}"
|
||||||
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
else:
|
||||||
|
params.suffix += f"-context-{params.context_size}"
|
||||||
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
|
lexicon=lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
if "LG" in params.decoding_method:
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
lg_filename = params.lang_dir / "LG.pt"
|
||||||
|
logging.info(f"Loading {lg_filename}")
|
||||||
|
decoding_graph = k2.Fsa.from_dict(
|
||||||
|
torch.load(lg_filename, map_location=device)
|
||||||
|
)
|
||||||
|
decoding_graph.scores *= params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
else:
|
||||||
|
decoding_graph = None
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
aishell = AishellAsrDataModule(args)
|
||||||
|
|
||||||
|
def remove_short_utt(c: Cut):
|
||||||
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
|
if T <= 0:
|
||||||
|
logging.warning(
|
||||||
|
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
|
||||||
|
)
|
||||||
|
return T > 0
|
||||||
|
|
||||||
|
dev_cuts = aishell.valid_cuts()
|
||||||
|
dev_cuts = dev_cuts.filter(remove_short_utt)
|
||||||
|
dev_dl = aishell.valid_dataloaders(dev_cuts)
|
||||||
|
|
||||||
|
test_cuts = aishell.test_cuts()
|
||||||
|
test_cuts = test_cuts.filter(remove_short_utt)
|
||||||
|
test_dl = aishell.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
|
test_sets = ["dev", "test"]
|
||||||
|
test_dls = [dev_dl, test_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
lexicon=lexicon,
|
||||||
|
graph_compiler=graph_compiler,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_results(
|
||||||
|
params=params,
|
||||||
|
test_set_name=test_set,
|
||||||
|
results_dict=results_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
egs/aishell/ASR/zipformer/decode_stream.py
Symbolic link
1
egs/aishell/ASR/zipformer/decode_stream.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/decode_stream.py
|
1
egs/aishell/ASR/zipformer/decoder.py
Symbolic link
1
egs/aishell/ASR/zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/decoder.py
|
1
egs/aishell/ASR/zipformer/encoder_interface.py
Symbolic link
1
egs/aishell/ASR/zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/encoder_interface.py
|
1
egs/aishell/ASR/zipformer/export-onnx-streaming.py
Symbolic link
1
egs/aishell/ASR/zipformer/export-onnx-streaming.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/export-onnx-streaming.py
|
1
egs/aishell/ASR/zipformer/export-onnx.py
Symbolic link
1
egs/aishell/ASR/zipformer/export-onnx.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/export-onnx.py
|
1
egs/aishell/ASR/zipformer/export.py
Symbolic link
1
egs/aishell/ASR/zipformer/export.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/export.py
|
1
egs/aishell/ASR/zipformer/jit_pretrained.py
Symbolic link
1
egs/aishell/ASR/zipformer/jit_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/jit_pretrained.py
|
1
egs/aishell/ASR/zipformer/jit_pretrained_streaming.py
Symbolic link
1
egs/aishell/ASR/zipformer/jit_pretrained_streaming.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py
|
1
egs/aishell/ASR/zipformer/joiner.py
Symbolic link
1
egs/aishell/ASR/zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/joiner.py
|
1
egs/aishell/ASR/zipformer/model.py
Symbolic link
1
egs/aishell/ASR/zipformer/model.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/model.py
|
1
egs/aishell/ASR/zipformer/onnx_check.py
Symbolic link
1
egs/aishell/ASR/zipformer/onnx_check.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/onnx_check.py
|
286
egs/aishell/ASR/zipformer/onnx_decode.py
Executable file
286
egs/aishell/ASR/zipformer/onnx_decode.py
Executable file
@ -0,0 +1,286 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao,
|
||||||
|
# Xiaoyu Yang,
|
||||||
|
# Wei Kang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads ONNX exported models and uses them to decode the test sets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import AishellAsrDataModule
|
||||||
|
from lhotse.cut import Cut
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
|
|
||||||
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the decoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the joiner onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless7/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_char/tokens.txt",
|
||||||
|
help="Path to the tokens.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="Valid values are greedy_search and modified_beam_search",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
model: OnnxModel, token_table: k2.SymbolTable, batch: dict
|
||||||
|
) -> List[List[str]]:
|
||||||
|
"""Decode one batch and return the result.
|
||||||
|
Currently it only greedy_search is supported.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
Mapping ids to tokens.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return the decoded results for each utterance.
|
||||||
|
"""
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
|
||||||
|
|
||||||
|
hyps = greedy_search(
|
||||||
|
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
|
||||||
|
)
|
||||||
|
|
||||||
|
hyps = [[token_table[h] for h in hyp] for hyp in hyps]
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
model: nn.Module,
|
||||||
|
token_table: k2.SymbolTable,
|
||||||
|
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
token_table:
|
||||||
|
Mapping ids to tokens.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- A list of tuples. Each tuple contains three elements:
|
||||||
|
- cut_id,
|
||||||
|
- reference transcript,
|
||||||
|
- predicted result.
|
||||||
|
- The total duration (in seconds) of the dataset.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
log_interval = 10
|
||||||
|
total_duration = 0
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
|
||||||
|
|
||||||
|
hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
|
||||||
|
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
|
ref_words = list(ref_text)
|
||||||
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
|
results.extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
|
||||||
|
return results, total_duration
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
res_dir: Path,
|
||||||
|
test_set_name: str,
|
||||||
|
results: List[Tuple[str, List[str], List[str]]],
|
||||||
|
):
|
||||||
|
recog_path = res_dir / f"recogs-{test_set_name}.txt"
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = res_dir / f"errs-{test_set_name}.txt"
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("WER", file=f)
|
||||||
|
print(wer, file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
AishellAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
args.decoding_method == "greedy_search"
|
||||||
|
), "Only supports greedy_search currently."
|
||||||
|
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
|
||||||
|
|
||||||
|
setup_logger(f"{res_dir}/log-decode")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
token_table = k2.SymbolTable.from_file(args.tokens)
|
||||||
|
assert token_table[0] == "<blk>"
|
||||||
|
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = OnnxModel(
|
||||||
|
encoder_model_filename=args.encoder_model_filename,
|
||||||
|
decoder_model_filename=args.decoder_model_filename,
|
||||||
|
joiner_model_filename=args.joiner_model_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
|
||||||
|
aishell = AishellAsrDataModule(args)
|
||||||
|
|
||||||
|
def remove_short_utt(c: Cut):
|
||||||
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
|
if T <= 0:
|
||||||
|
logging.warning(
|
||||||
|
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
|
||||||
|
)
|
||||||
|
return T > 0
|
||||||
|
|
||||||
|
dev_cuts = aishell.valid_cuts()
|
||||||
|
dev_cuts = dev_cuts.filter(remove_short_utt)
|
||||||
|
dev_dl = aishell.valid_dataloaders(dev_cuts)
|
||||||
|
|
||||||
|
test_cuts = aishell.test_net_cuts()
|
||||||
|
test_cuts = test_cuts.filter(remove_short_utt)
|
||||||
|
test_dl = aishell.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
|
test_sets = ["dev", "test"]
|
||||||
|
test_dl = [dev_dl, test_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
|
start_time = time.time()
|
||||||
|
results, total_duration = decode_dataset(
|
||||||
|
dl=test_dl, model=model, token_table=token_table
|
||||||
|
)
|
||||||
|
end_time = time.time()
|
||||||
|
elapsed_seconds = end_time - start_time
|
||||||
|
rtf = elapsed_seconds / total_duration
|
||||||
|
|
||||||
|
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||||
|
logging.info(f"Wave duration: {total_duration:.3f} s")
|
||||||
|
logging.info(
|
||||||
|
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py
Symbolic link
1
egs/aishell/ASR/zipformer/onnx_pretrained-streaming.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user