Add ONNX support for Zipformer and ConvEmformer (#884)

This commit is contained in:
Fangjun Kuang 2023-02-09 00:02:38 +08:00 committed by GitHub
parent af735eb75b
commit 2b995639b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 3806 additions and 1606 deletions

View File

@ -1,79 +0,0 @@
#!/usr/bin/env bash
#
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
log "Downloading pre-trained model from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
git lfs pull --include "data/lang_bpe_500/bpe.model"
cd exp
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
popd
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
log "Install ncnn and pnnx"
# We are using a modified ncnn here. Will try to merge it to the official repo
# of ncnn
git clone https://github.com/csukuangfj/ncnn
pushd ncnn
git submodule init
git submodule update python/pybind11
python3 setup.py bdist_wheel
ls -lh dist/
pip install dist/*.whl
cd tools/pnnx
mkdir build
cd build
cmake -D Python3_EXECUTABLE=/opt/hostedtoolcache/Python/3.8.14/x64/bin/python3 ..
make -j4 pnnx
./src/pnnx || echo "pass"
popd
log "Test exporting to pnnx format"
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
\
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav

View File

@ -28,63 +28,6 @@ ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd popd
log "Install ncnn and pnnx"
# We are using a modified ncnn here. Will try to merge it to the official repo
# of ncnn
git clone https://github.com/csukuangfj/ncnn
pushd ncnn
git submodule init
git submodule update python/pybind11
python3 setup.py bdist_wheel
ls -lh dist/
pip install dist/*.whl
cd tools/pnnx
mkdir build
cd build
cmake ..
make -j4 pnnx
./src/pnnx || echo "pass"
popd
log "Test exporting to pnnx format"
./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--pnnx 1
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
log "Test exporting with torch.jit.trace()" log "Test exporting with torch.jit.trace()"
./lstm_transducer_stateless2/export.py \ ./lstm_transducer_stateless2/export.py \
@ -106,47 +49,6 @@ log "Decode with models exported by torch.jit.trace()"
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav
log "Test exporting to ONNX"
./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--onnx 1
log "Decode with ONNX models "
./lstm_transducer_stateless2/streaming-onnx-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo//exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1089-134686-0001.wav
./lstm_transducer_stateless2/streaming-onnx-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo//exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1221-135766-0001.wav
./lstm_transducer_stateless2/streaming-onnx-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo//exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1221-135766-0002.wav
for sym in 1 2 3; do for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym" log "Greedy search with --max-sym-per-frame $sym"

View File

@ -30,15 +30,6 @@ ln -s pretrained.pt epoch-99.pt
ls -lh *.pt ls -lh *.pt
popd popd
log "Test exporting to ONNX format"
./pruned_transducer_stateless7/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--onnx 1
log "Export to torchscript model" log "Export to torchscript model"
./pruned_transducer_stateless7/export.py \ ./pruned_transducer_stateless7/export.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
@ -50,27 +41,6 @@ log "Export to torchscript model"
ls -lh $repo/exp/*.pt ls -lh $repo/exp/*.pt
log "Decode with ONNX models"
./pruned_transducer_stateless7/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder.onnx \
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx \
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
./pruned_transducer_stateless7/onnx_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
log "Decode with models exported by torch.jit.script()" log "Decode with models exported by torch.jit.script()"
./pruned_transducer_stateless7/jit_pretrained.py \ ./pruned_transducer_stateless7/jit_pretrained.py \

View File

@ -34,16 +34,6 @@ ln -s pretrained.pt epoch-99.pt
ls -lh *.pt ls -lh *.pt
popd popd
log "Test exporting to ONNX format"
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--fp16 \
--onnx 1
log "Export to torchscript model" log "Export to torchscript model"
./pruned_transducer_stateless7_streaming/export.py \ ./pruned_transducer_stateless7_streaming/export.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \

133
.github/scripts/test-ncnn-export.sh vendored Executable file
View File

@ -0,0 +1,133 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
log "Install ncnn and pnnx"
# We are using a modified ncnn here. Will try to merge it to the official repo
# of ncnn
git clone https://github.com/csukuangfj/ncnn
pushd ncnn
git submodule init
git submodule update python/pybind11
python3 setup.py bdist_wheel
ls -lh dist/
pip install dist/*.whl
cd tools/pnnx
mkdir build
cd build
echo "which python3"
which python3
#/opt/hostedtoolcache/Python/3.8.16/x64/bin/python3
cmake -D Python3_EXECUTABLE=$(which python3) ..
make -j4 pnnx
./src/pnnx || echo "pass"
popd
log "=========================================================================="
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
cd exp
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
popd
log "Export via torch.jit.trace()"
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
\
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
python3 ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
rm -rf $repo
log "--------------------------------------------------------------------------"
log "=========================================================================="
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
log "Export via torch.jit.trace()"
./lstm_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
python3 ./lstm_transducer_stateless2/ncnn-decode.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1089-134686-0001.wav
rm -rf $repo
log "--------------------------------------------------------------------------"

View File

@ -10,6 +10,8 @@ log() {
cd egs/librispeech/ASR cd egs/librispeech/ASR
log "==========================================================================" log "=========================================================================="
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
log "Downloading pre-trained model from $repo_url" log "Downloading pre-trained model from $repo_url"
@ -192,3 +194,158 @@ log "Run onnx_pretrained.py"
rm -rf $repo rm -rf $repo
log "--------------------------------------------------------------------------" log "--------------------------------------------------------------------------"
log "=========================================================================="
repo_url=
rm -rf $repo
log "--------------------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
log "Export via torch.jit.script()"
./pruned_transducer_stateless7/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--feedforward-dims "1024,1024,2048,2048,1024" \
--jit 1
log "Test exporting to ONNX format"
./pruned_transducer_stateless7/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--feedforward-dims "1024,1024,2048,2048,1024"
ls -lh $repo/exp
log "Run onnx_check.py"
./pruned_transducer_stateless7/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
log "Run onnx_pretrained.py"
./pruned_transducer_stateless7/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
log "=========================================================================="
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
cd exp
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
popd
log "Test exporting to ONNX format"
./conv_emformer_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32
log "Run onnx_pretrained.py"
./conv_emformer_transducer_stateless2/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1221-135766-0001.wav
rm -rf $repo
log "--------------------------------------------------------------------------"
log "=========================================================================="
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
log "Export via torch.jit.trace()"
./lstm_transducer_stateless2/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp/ \
--jit-trace 1
log "Test exporting to ONNX format"
./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp
ls -lh $repo/exp
log "Run onnx_check.py"
./lstm_transducer_stateless2/onnx_check.py \
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
log "Run onnx_pretrained.py"
./lstm_transducer_stateless2/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1221-135766-0001.wav
rm -rf $repo
log "--------------------------------------------------------------------------"

View File

@ -39,7 +39,7 @@ concurrency:
jobs: jobs:
run_librispeech_2022_11_11_zipformer: run_librispeech_2022_11_11_zipformer:
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:

View File

@ -39,7 +39,7 @@ concurrency:
jobs: jobs:
run_librispeech_2022_12_29_zipformer_streaming: run_librispeech_2022_12_29_zipformer_streaming:
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:

View File

@ -22,7 +22,7 @@ concurrency:
jobs: jobs:
run_librispeech_lstm_transducer_stateless2_2022_09_03: run_librispeech_lstm_transducer_stateless2_2022_09_03:
if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:

View File

@ -1,4 +1,4 @@
name: run-librispeech-conv-emformer-transducer-stateless2-2022-12-05 name: test-ncnn-export
on: on:
push: push:
@ -16,15 +16,18 @@ on:
# nightly build at 15:50 UTC time every day # nightly build at 15:50 UTC time every day
- cron: "50 15 * * *" - cron: "50 15 * * *"
concurrency:
group: test_ncnn_export-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
run_librispeech_conv_emformer_transducer_stateless2_2022_12_05: test_ncnn_export:
if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule' if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest] os: [ubuntu-latest]
python-version: [3.8] python-version: [3.8]
fail-fast: false fail-fast: false
steps: steps:
@ -41,7 +44,7 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf pip uninstall -y protobuf
pip install --no-binary protobuf protobuf pip install --no-binary protobuf protobuf
@ -59,19 +62,14 @@ jobs:
run: | run: |
.github/scripts/install-kaldifeat.sh .github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model - name: Test ncnn export
shell: bash shell: bash
env: env:
GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: | run: |
mkdir -p egs/librispeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh .github/scripts/test-ncnn-export.sh

View File

@ -1,69 +1,78 @@
Export to ONNX Export to ONNX
============== ==============
In this section, we describe how to export models to ONNX. In this section, we describe how to export the following models to ONNX.
In each recipe, there is a file called ``export-onnx.py``, which is used
to export trained models to ONNX.
There is also a file named ``onnx_pretrained.py``, which you can use
the exported ONNX model in Python to decode sound files.
Example
=======
In the following, we demonstrate how to export a streaming Zipformer pre-trained
model from `<python3 ./python-api-examples/speech-recognition-from-microphone.py>`_
to ONNX.
Download the pre-trained model
------------------------------
.. hint:: .. hint::
Only non-streaming conformer transducer models are tested. We assume you have installed `git-lfs`_.
When to use it
--------------
It you want to use an inference framework that supports ONNX
to run the pretrained model.
How to export
-------------
We use
`<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless3>`_
as an example in the following.
.. code-block:: bash .. code-block:: bash
cd egs/librispeech/ASR cd egs/librispeech/ASR
epoch=14
avg=2
./pruned_transducer_stateless3/export.py \ repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
--exp-dir ./pruned_transducer_stateless3/exp \ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
--bpe-model data/lang_bpe_500/bpe.model \ repo=$(basename $repo_url)
--epoch $epoch \
--avg $avg \
--onnx 1
It will generate the following files inside ``pruned_transducer_stateless3/exp``: pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
- ``encoder.onnx`` Export the model to ONNX
- ``decoder.onnx`` ------------------------
- ``joiner.onnx``
- ``joiner_encoder_proj.onnx``
- ``joiner_decoder_proj.onnx``
You can use ``./pruned_transducer_stateless3/exp/onnx_pretrained.py`` to decode
waves with the generated files:
.. code-block:: bash .. code-block:: bash
./pruned_transducer_stateless3/onnx_pretrained.py \ ./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model ./data/lang_bpe_500/bpe.model \ --bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ --use-averaged-model 0 \
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ --epoch 99 \
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ --avg 1 \
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \ --decode-chunk-len 32 \
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \ --exp-dir $repo/exp/
/path/to/foo.wav \
/path/to/bar.wav \
/path/to/baz.wav
.. warning::
How to use the exported model ``export-onnx.py`` from different recipes has different options.
-----------------------------
We also provide `<https://github.com/k2-fsa/sherpa-onnx>`_ In the above example, ``--decode-chunk-len`` is specific for the
performing speech recognition using `onnxruntime <https://github.com/microsoft/onnxruntime>`_ streaming Zipformer. Other models won't have such an option.
with exported models.
It has been tested on Linux, macOS, and Windows. It will generate the following 3 files in ``$repo/exp``
- ``encoder-epoch-99-avg-1.onnx``
- ``decoder-epoch-99-avg-1.onnx``
- ``joiner-epoch-99-avg-1.onnx``
Decode sound files with exported ONNX models
--------------------------------------------
.. code-block:: bash
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav

View File

@ -580,12 +580,11 @@ for ``pnnx``:
iter=468000 iter=468000
avg=16 avg=16
./lstm_transducer_stateless2/export.py \ ./lstm_transducer_stateless2/export-for-ncnn.py \
--exp-dir ./lstm_transducer_stateless2/exp \ --exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--iter $iter \ --iter $iter \
--avg $avg \ --avg $avg
--pnnx 1
It will generate 3 files: It will generate 3 files:
@ -615,7 +614,7 @@ To use the above generated files, run:
.. code-block:: bash .. code-block:: bash
./lstm_transducer_stateless2/ncnn-decode.py \ ./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
@ -627,7 +626,7 @@ To use the above generated files, run:
.. code-block:: bash .. code-block:: bash
./lstm_transducer_stateless2/streaming-ncnn-decode.py \ ./lstm_transducer_stateless2/streaming-ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
@ -657,6 +656,3 @@ by visiting the following links:
You can find more usages of the pretrained models in You can find more usages of the pretrained models in
`<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_ `<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_
Export ConvEmformer transducer models for ncnn
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -169,9 +169,11 @@ class ConvolutionModule(nn.Module):
channels: int, channels: int,
kernel_size: int, kernel_size: int,
bias: bool = True, bias: bool = True,
is_pnnx: bool = True,
) -> None: ) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super().__init__() super().__init__()
self.is_pnnx = is_pnnx
# kernerl_size should be an odd number for 'SAME' padding # kernerl_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0, kernel_size assert (kernel_size - 1) % 2 == 0, kernel_size
@ -383,8 +385,10 @@ class ConvolutionModule(nn.Module):
- output right_context of shape (R, B, D). - output right_context of shape (R, B, D).
- updated cache tensor of shape (B, D, cache_size). - updated cache tensor of shape (B, D, cache_size).
""" """
# U, B, D = utterance.size() if self.is_pnnx is False:
# R, _, _ = right_context.size() U, B, D = utterance.size()
R, _, _ = right_context.size()
else:
U = self.chunk_length U = self.chunk_length
B = 1 B = 1
D = self.channels D = self.channels
@ -448,8 +452,10 @@ class EmformerAttention(nn.Module):
dropout: float = 0.0, dropout: float = 0.0,
tanh_on_mem: bool = False, tanh_on_mem: bool = False,
negative_inf: float = -1e8, negative_inf: float = -1e8,
is_pnnx: bool = True,
): ):
super().__init__() super().__init__()
self.is_pnnx = is_pnnx
if embed_dim % nhead != 0: if embed_dim % nhead != 0:
raise ValueError( raise ValueError(
@ -539,10 +545,11 @@ class EmformerAttention(nn.Module):
left_context_val: Optional[torch.Tensor] = None, left_context_val: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Underlying chunk-wise attention implementation.""" """Underlying chunk-wise attention implementation."""
# U, B, _ = utterance.size() if self.is_pnnx is False:
# R = right_context.size(0) U, B, _ = utterance.size()
# M = memory.size(0) R = right_context.size(0)
M = memory.size(0)
else:
U = self.chunk_length U = self.chunk_length
B = 1 B = 1
R = self.right_context_length R = self.right_context_length
@ -570,21 +577,29 @@ class EmformerAttention(nn.Module):
# KV = key.size(0) # KV = key.size(0)
if self.is_pnnx is True:
reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2) reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2)
reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute( reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute(
1, 0, 2 1, 0, 2
) )
reshaped_value = value.view(M + R + U + L, self.nhead, self.head_dim).permute( reshaped_value = value.view(
1, 0, 2 M + R + U + L, self.nhead, self.head_dim
) ).permute(1, 0, 2)
else:
# reshaped_query, reshaped_key, reshaped_value = [ reshaped_query, reshaped_key, reshaped_value = [
# tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) tensor.contiguous()
# for tensor in [query, key, value] .view(-1, B * self.nhead, self.head_dim)
# ] # (B * nhead, Q or KV, head_dim) .transpose(0, 1)
for tensor in [query, key, value]
] # (B * nhead, Q or KV, head_dim)
if self.is_pnnx is True:
attention_weights = torch.bmm( attention_weights = torch.bmm(
reshaped_query * scaling, reshaped_key.permute(0, 2, 1) reshaped_query * scaling, reshaped_key.permute(0, 2, 1)
) # (B * nhead, Q, KV) ) # (B * nhead, Q, KV)
else:
attention_weights = torch.bmm(
reshaped_query * scaling, reshaped_key.transpose(1, 2)
) # (B * nhead, Q, KV)
# compute attention probabilities # compute attention probabilities
if False: if False:
@ -597,10 +612,15 @@ class EmformerAttention(nn.Module):
# compute attention outputs # compute attention outputs
attention = torch.bmm(attention_probs, reshaped_value) attention = torch.bmm(attention_probs, reshaped_value)
assert attention.shape == (B * self.nhead, Q, self.head_dim) assert attention.shape == (B * self.nhead, Q, self.head_dim)
if self.is_pnnx is True:
attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim) attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim)
# TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim) # TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim)
# We have to change InnerProduct in ncnn to ignore the extra dim below # We have to change InnerProduct in ncnn to ignore the extra dim below
attention = attention.unsqueeze(1) attention = attention.unsqueeze(1)
else:
attention = (
attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
)
# apply output projection # apply output projection
output_right_context_utterance = self.out_proj(attention) output_right_context_utterance = self.out_proj(attention)
@ -733,11 +753,12 @@ class EmformerAttention(nn.Module):
- attention value of left context and utterance, which would be - attention value of left context and utterance, which would be
cached for next computation, with shape (L + U, B, D). cached for next computation, with shape (L + U, B, D).
""" """
# U = utterance.size(0) if self.is_pnnx is False:
# R = right_context.size(0) U = utterance.size(0)
# L = left_context_key.size(0) R = right_context.size(0)
# M = memory.size(0) L = left_context_key.size(0)
M = memory.size(0)
else:
U = self.chunk_length U = self.chunk_length
R = self.right_context_length R = self.right_context_length
L = self.left_context_length L = self.left_context_length
@ -811,6 +832,7 @@ class EmformerEncoderLayer(nn.Module):
memory_size: int = 0, memory_size: int = 0,
tanh_on_mem: bool = False, tanh_on_mem: bool = False,
negative_inf: float = -1e8, negative_inf: float = -1e8,
is_pnnx: bool = True,
): ):
super().__init__() super().__init__()
@ -824,6 +846,7 @@ class EmformerEncoderLayer(nn.Module):
dropout=dropout, dropout=dropout,
tanh_on_mem=tanh_on_mem, tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf, negative_inf=negative_inf,
is_pnnx=is_pnnx,
) )
self.summary_op = nn.AvgPool1d( self.summary_op = nn.AvgPool1d(
kernel_size=chunk_length, stride=chunk_length, ceil_mode=True kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
@ -850,6 +873,7 @@ class EmformerEncoderLayer(nn.Module):
right_context_length, right_context_length,
d_model, d_model,
cnn_module_kernel, cnn_module_kernel,
is_pnnx=is_pnnx,
) )
self.norm_final = BasicNorm(d_model) self.norm_final = BasicNorm(d_model)
@ -1204,6 +1228,7 @@ class EmformerEncoder(nn.Module):
memory_size: int = 0, memory_size: int = 0,
tanh_on_mem: bool = False, tanh_on_mem: bool = False,
negative_inf: float = -1e8, negative_inf: float = -1e8,
is_pnnx: bool = True,
): ):
super().__init__() super().__init__()
@ -1229,6 +1254,7 @@ class EmformerEncoder(nn.Module):
memory_size=memory_size, memory_size=memory_size,
tanh_on_mem=tanh_on_mem, tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf, negative_inf=negative_inf,
is_pnnx=is_pnnx,
) )
for layer_idx in range(num_encoder_layers) for layer_idx in range(num_encoder_layers)
] ]
@ -1561,6 +1587,20 @@ class Emformer(EncoderInterface):
self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx) self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx)
self.is_pnnx = is_pnnx self.is_pnnx = is_pnnx
self.num_encoder_layers = num_encoder_layers
self.memory_size = memory_size
self.d_model = d_model
self.cnn_module_kernel = cnn_module_kernel
self.left_context_length = left_context_length // subsampling_factor
self.right_context_length = right_context_length
self.subsampling_factor = subsampling_factor
assert subsampling_factor == 4, subsampling_factor
pad_length = right_context_length + 2 * 4 + 3
# before subsampling
self.T = self.chunk_length + pad_length
self.encoder = EmformerEncoder( self.encoder = EmformerEncoder(
chunk_length=chunk_length // subsampling_factor, chunk_length=chunk_length // subsampling_factor,
d_model=d_model, d_model=d_model,
@ -1575,6 +1615,7 @@ class Emformer(EncoderInterface):
memory_size=memory_size, memory_size=memory_size,
tanh_on_mem=tanh_on_mem, tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf, negative_inf=negative_inf,
is_pnnx=is_pnnx,
) )
def _forward( def _forward(
@ -1691,7 +1732,7 @@ class Conv2dSubsampling(nn.Module):
layer1_channels: int = 8, layer1_channels: int = 8,
layer2_channels: int = 32, layer2_channels: int = 32,
layer3_channels: int = 128, layer3_channels: int = 128,
is_pnnx: bool = False, is_pnnx: bool = True,
) -> None: ) -> None:
""" """
Args: Args:
@ -1767,7 +1808,7 @@ class Conv2dSubsampling(nn.Module):
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x) x = self.conv(x)
if torch.jit.is_tracing() and self.is_pnnx: if torch.jit.is_tracing() and self.is_pnnx is True:
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
x = self.out(x) x = self.out(x)
else: else:

View File

@ -1,6 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for more details about how to use this file.
Usage: Usage:
./conv_emformer_transducer_stateless2/export-for-ncnn.py \ ./conv_emformer_transducer_stateless2/export-for-ncnn.py \
--exp-dir ./conv_emformer_transducer_stateless2/exp \ --exp-dir ./conv_emformer_transducer_stateless2/exp \
@ -44,7 +48,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import setup_logger, str2bool
def get_parser(): def get_parser():
@ -96,14 +100,6 @@ def get_parser():
help="Path to the BPE model", help="Path to the BPE model",
) )
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
@ -217,6 +213,8 @@ def main():
device = torch.device("cpu") device = torch.device("cpu")
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
@ -330,5 +328,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

View File

@ -0,0 +1,644 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
cd exp
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
popd
2. Export the model to ONNX
./conv_emformer_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
from emformer import Emformer
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Emformer and the encoder_proj from the joiner"""
def __init__(self, encoder: Emformer, encoder_proj: nn.Linear):
"""
Args:
encoder:
A Emformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_proj = encoder_proj
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Emformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, encoder_out_lens
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has the following inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- a list of states (each layers has 4 states)
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- a list of states (each layers has 4 states)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
num_encoder_layers = encoder_model.encoder.num_encoder_layers
memory_size = encoder_model.encoder.memory_size
cnn_module_kernel = encoder_model.encoder.cnn_module_kernel
chunk_length = encoder_model.encoder.chunk_length
right_context_length = encoder_model.encoder.right_context_length
encoder_dim = encoder_model.encoder.d_model
left_context_length = encoder_model.encoder.left_context_length
T = encoder_model.encoder.T
logging.info(f"num_encoder_layers={num_encoder_layers}")
logging.info(f"memory_size={memory_size}")
logging.info(f"cnn_module_kernel={cnn_module_kernel}")
logging.info(f"chunk_length={chunk_length}")
logging.info(f"right_context_length={right_context_length}")
logging.info(f"encoder_dim={encoder_dim}")
logging.info(f"left_context_length={left_context_length} (after subsampling)")
logging.info(f"T={T}")
meta_data = {
"model_type": "conv-emformer",
"version": "1",
"model_author": "k2-fsa",
"decode_chunk_len": str(chunk_length), # 32
"T": str(T), # 32
"num_encoder_layers": str(num_encoder_layers),
"memory_size": str(memory_size),
"cnn_module_kernel": str(cnn_module_kernel),
"right_context_length": str(right_context_length),
"left_context_length": str(left_context_length),
"encoder_dim": str(encoder_dim),
}
logging.info(f"meta_data: {meta_data}")
x = torch.zeros(1, T, 80, dtype=torch.float32)
states = encoder_model.encoder.init_states()
# Each layer has 4 states
assert len(states) == num_encoder_layers * 4, (len(states), num_encoder_layers)
# layer 0:
# state0: (memory_size, 1, encoder_dim)
# state1: (left_context_length, 1, encoder_dim)
# state2: (left_context_length, 1, encoder_dim)
# state3: (1, encoder_dim, cnn_module_kernel-1)
inputs = {}
input_names = ["x"]
outputs = {}
output_names = ["encoder_out"]
def build_inputs_outputs(s, name):
assert len(s) == 4, len(s)
logging.info(f"{name}_0.shape: {s[0].shape}")
input_names.append(f"{name}_0")
inputs[f"{name}_0"] = {1: "N"}
output_names.append(f"new_{name}_0")
logging.info(f"{name}_1.shape: {s[1].shape}")
input_names.append(f"{name}_1")
inputs[f"{name}_1"] = {1: "N"}
output_names.append(f"new_{name}_1")
logging.info(f"{name}_2.shape: {s[2].shape}")
input_names.append(f"{name}_2")
inputs[f"{name}_2"] = {1: "N"}
output_names.append(f"new_{name}_2")
logging.info(f"{name}_3.shape: {s[3].shape}")
input_names.append(f"{name}_3")
inputs[f"{name}_3"] = {0: "N"}
output_names.append(f"new_{name}_3")
for i in range(num_encoder_layers):
base_name = f"layer{i}"
s = states[i * 4 : (i + 1) * 4]
build_inputs_outputs(s, base_name)
torch.onnx.export(
encoder_model,
(x, states),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"x": {0: "N"},
"encoder_out": {0: "N"},
**inputs,
**outputs,
},
)
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.is_pnnx = False
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
main()

View File

@ -64,6 +64,7 @@ from pathlib import Path
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
@ -258,6 +259,7 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore # We won't use the forward() method of the model in C++, so just ignore
# it here. # it here.
# Otherwise, one of its arguments is a ragged tensor and is not # Otherwise, one of its arguments is a ragged tensor and is not

View File

@ -0,0 +1,456 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script loads ONNX models exported by ./export-onnx.py
and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
cd exp
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
popd
2. Export the model to ONNX
./conv_emformer_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
--left-context-length 32 \
--right-context-length 8 \
--memory-size 32
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file with the exported ONNX models
./conv_emformer_transducer_stateless2/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav
Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
"""
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
model_type = encoder_meta["model_type"]
assert model_type == "conv-emformer", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
T = int(encoder_meta["T"])
num_encoder_layers = int(encoder_meta["num_encoder_layers"])
memory_size = int(encoder_meta["memory_size"])
cnn_module_kernel = int(encoder_meta["cnn_module_kernel"])
right_context_length = int(encoder_meta["right_context_length"])
left_context_length = int(encoder_meta["left_context_length"])
encoder_dim = int(encoder_meta["encoder_dim"])
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
logging.info(f"num_encoder_layers: {num_encoder_layers}")
logging.info(f"memory_size: {memory_size}")
logging.info(f"cnn_module_kernel: {cnn_module_kernel}")
logging.info(f"left_context_length: {left_context_length} (after subsampling)")
logging.info(f"right_context_length: {right_context_length}")
logging.info(f"encoder_dim: {encoder_dim}")
N = batch_size
states = []
for i in range(num_encoder_layers):
s0 = torch.zeros(memory_size, N, encoder_dim)
s1 = torch.zeros(left_context_length, N, encoder_dim)
s2 = torch.zeros(left_context_length, N, encoder_dim)
s3 = torch.zeros(N, encoder_dim, cnn_module_kernel - 1)
states.extend([s0, s1, s2, s3])
self.states = states
self.segment = T
self.offset = decode_chunk_len
self.num_encoder_layers = num_encoder_layers
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def _build_encoder_input_output(
self,
x: torch.Tensor,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
encoder_input = {"x": x.numpy()}
encoder_output = ["encoder_out"]
def build_inputs_outputs(states: List[torch.Tensor], name: str):
for i in range(4):
if isinstance(states[i], torch.Tensor):
encoder_input[f"{name}_{i}"] = states[i].numpy()
else:
encoder_input[f"{name}_{i}"] = states[i]
encoder_output.append(f"new_{name}_{i}")
for i in range(self.num_encoder_layers):
base_name = f"layer{i}"
s = self.states[i * 4 : (i + 1) * 4]
build_inputs_outputs(s, base_name)
return encoder_input, encoder_output
def _update_states(self, states: List[np.ndarray]):
self.states = states
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
T' is usually equal to ((T-7)//2+1)//2
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
out = self.encoder.run(encoder_output_names, encoder_input)
self._update_states(out[1:])
return torch.from_numpy(out[0])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
context_size: int,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
) -> List[int]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (1, T, joiner_dim)
context_size:
The context size of the decoder model.
decoder_out:
Optional. Decoder output of the previous chunk.
hyp:
Decoding results for previous chunks.
Returns:
Return the decoded results so far.
"""
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor([hyp], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
else:
assert hyp is not None, hyp
encoder_out = encoder_out.squeeze(0)
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t : t + 1]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
return hyp, decoder_out
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {args.sound_file}")
waves = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=sample_rate,
)[0]
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([waves, tail_padding])
num_processed_frames = 0
segment = model.segment
offset = model.offset
context_size = model.context_size
hyp = None
decoder_out = None
chunk = int(1 * sample_rate) # 1 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
encoder_out = model.run_encoder(frames)
hyp, decoder_out = greedy_search(
model,
encoder_out,
context_size,
decoder_out,
hyp,
)
symbol_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
text = text.replace("", " ").strip()
logging.info(args.sound_file)
logging.info(text)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -425,6 +425,7 @@ def get_params() -> AttributeDict:
"joiner_dim": 512, "joiner_dim": 512,
# parameters for Noam # parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate "model_warm_step": 3000, # arg given to model, not for lrate
"is_pnnx": True,
"env_info": get_env_info(), "env_info": get_env_info(),
} }
) )
@ -446,6 +447,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
left_context_length=params.left_context_length, left_context_length=params.left_context_length,
right_context_length=params.right_context_length, right_context_length=params.right_context_length,
memory_size=params.memory_size, memory_size=params.memory_size,
is_pnnx=params.is_pnnx,
) )
return encoder return encoder

View File

@ -0,0 +1 @@
../lstm_transducer_stateless2/export-onnx.py

View File

@ -0,0 +1 @@
../lstm_transducer_stateless2/onnx_check.py

View File

@ -0,0 +1 @@
../lstm_transducer_stateless2/onnx_pretrained.py

View File

@ -0,0 +1,337 @@
#!/usr/bin/env python3
"""
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for more details about how to use this file.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
2. Export via torch.jit.trace()
./lstm_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
cd ./lstm_transducer_stateless2/exp
pnnx encoder_jit_trace-pnnx.pt
pnnx decoder_jit_trace-pnnx.pt
pnnx joiner_jit_trace-pnnx.pt
See ./streaming-ncnn-decode.py
and
https://github.com/k2-fsa/sherpa-ncnn
for usage.
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: torch.nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
states = encoder_model.get_init_states()
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: torch.nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: torch.nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
params.is_pnnx = True
logging.info("About to create model")
model = get_transducer_model(params, enable_giga=False)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
main()

View File

@ -0,0 +1,593 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
2. Export the model to ONNX
./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Optional, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from lstm import RNN
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for RNN and the encoder_proj from the joiner"""
def __init__(self, encoder: RNN, encoder_proj: nn.Linear):
"""
Args:
encoder:
An RNN encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_proj = encoder_proj
def forward(
self,
x: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of RNN.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (num_layers, N, d_model);
states[1] is the cell states of all layers,
with shape of (num_layers, N, rnn_hidden_size).
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- updated states, whose shape is the same as the input states.
"""
N = x.size(0)
T = x.size(1)
x_lens = torch.tensor([T] * N, dtype=torch.int64, device=x.device)
encoder_out, _, next_states = self.encoder(x, x_lens, states)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, next_states
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has the following inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- state0, a tensor of shape (num_encoder_layers, batch_size, d_model)
- state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size)
and it has 3 outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- new_state0, a tensor of shape (num_encoder_layers, batch_size, d_model)
- new_state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
num_encoder_layers = encoder_model.encoder.num_encoder_layers
d_model = encoder_model.encoder.d_model
rnn_hidden_size = encoder_model.encoder.rnn_hidden_size
decode_chunk_len = 4
T = 9
x = torch.zeros(1, T, 80, dtype=torch.float32)
states = encoder_model.encoder.get_init_states()
# state0: (num_encoder_layers, batch_size, d_model)
# state1: (num_encoder_layers, batch_size, rnn_hidden_size)
torch.onnx.export(
encoder_model,
(x, states),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "state0", "state1"],
output_names=["encoder_out", "new_state0", "new_state1"],
dynamic_axes={
"x": {0: "N"},
"state0": {1: "N"},
"state1": {1: "N"},
"encoder_out": {0: "N"},
"new_state0": {1: "N"},
"new_state1": {1: "N"},
},
)
meta_data = {
"model_type": "lstm",
"version": "1",
"model_author": "k2-fsa",
"decode_chunk_len": str(decode_chunk_len), # 32
"T": str(T), # 39
"num_encoder_layers": str(num_encoder_layers),
"d_model": str(d_model),
"rnn_hidden_size": str(rnn_hidden_size),
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params, enable_giga=False)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
main()

View File

@ -74,29 +74,6 @@ with the following commands:
git lfs install git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
# You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp # You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp
(3) Export to ONNX format
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--onnx 1
It will generate the following files in the given `exp_dir`.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Please see ./streaming-onnx-decode.py for usage of the generated files
Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
""" """
import argparse import argparse
@ -192,35 +169,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--pnnx",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace for later
converting to PNNX. It will generate 3 files:
- encoder_jit_trace-pnnx.pt
- decoder_jit_trace-pnnx.pt
- joiner_jit_trace-pnnx.pt
""",
)
parser.add_argument(
"--onnx",
type=str2bool,
default=False,
help="""If True, --jit and --pnnx are ignored and it exports the model
to onnx format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
@ -305,209 +253,6 @@ def export_joiner_model_jit_trace(
logging.info(f"Saved to {joiner_filename}") logging.info(f"Saved to {joiner_filename}")
def export_encoder_model_onnx(
encoder_model: nn.Module,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has 3 inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
- states: a tuple containing:
- h0: a tensor of shape (num_layers, N, proj_size)
- c0: a tensor of shape (num_layers, N, hidden_size)
and it has 3 outputs:
- encoder_out, a tensor of shape (N, T, C)
- encoder_out_lens, a tensor of shape (N,)
- states: a tuple containing:
- next_h0: a tensor of shape (num_layers, N, proj_size)
- next_c0: a tensor of shape (num_layers, N, hidden_size)
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
N = 1
x = torch.zeros(N, 9, 80, dtype=torch.float32)
x_lens = torch.tensor([9], dtype=torch.int64)
h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model)
c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size)
warmup = 1.0
torch.onnx.export(
encoder_model, # use torch.jit.trace() internally
(x, x_lens, (h, c), warmup),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens", "h", "c", "warmup"],
output_names=["encoder_out", "encoder_out_lens", "next_h", "next_c"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"h": {1: "N"},
"c": {1: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
"next_h": {1: "N"},
"next_c": {1: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = False # Always False, so we can use torch.jit.trace() here
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
# in this case
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"projected_encoder_out",
"projected_decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"projected_encoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_proj_filename}")
@torch.no_grad() @torch.no_grad()
def main(): def main():
args = get_parser().parse_args() args = get_parser().parse_args()
@ -531,10 +276,6 @@ def main():
logging.info(params) logging.info(params)
if params.pnnx:
params.is_pnnx = params.pnnx
logging.info("For PNNX")
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params, enable_giga=False) model = get_transducer_model(params, enable_giga=False)
@ -629,44 +370,7 @@ def main():
model.to("cpu") model.to("cpu")
model.eval() model.eval()
if params.onnx: if params.jit_trace is True:
logging.info("Export model to ONNX format")
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
opset_version = 11
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
elif params.pnnx:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
elif params.jit_trace is True:
convert_scaled_to_non_scaled(model, inplace=True) convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()") logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt" encoder_filename = params.exp_dir / "encoder_jit_trace.pt"

View File

@ -19,7 +19,7 @@
""" """
Usage: Usage:
./lstm_transducer_stateless2/ncnn-decode.py \ ./lstm_transducer_stateless2/ncnn-decode.py \
--bpe-model-filename ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
@ -27,15 +27,19 @@ Usage:
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
./test_wavs/1089-134686-0001.wav ./test_wavs/1089-134686-0001.wav
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for details.
""" """
import argparse import argparse
import logging import logging
from typing import List from typing import List
import k2
import kaldifeat import kaldifeat
import ncnn import ncnn
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
@ -44,9 +48,9 @@ def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--bpe-model-filename", "--tokens",
type=str, type=str,
help="Path to bpe.model", help="Path to tokens.txt",
) )
parser.add_argument( parser.add_argument(
@ -240,9 +244,6 @@ def main():
model = Model(args) model = Model(args)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model_filename)
sound_file = args.sound_filename sound_file = args.sound_filename
sample_rate = 16000 sample_rate = 16000
@ -280,8 +281,16 @@ def main():
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states) encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states)
hyp = greedy_search(model, encoder_out) hyp = greedy_search(model, encoder_out)
symbol_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp:
text += symbol_table[i]
text = text.replace("", " ").strip()
logging.info(sound_file) logging.info(sound_file)
logging.info(sp.decode(hyp)) logging.info(text)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -0,0 +1,261 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
"""
This script checks that exported ONNX models produce the same output
with the given torchscript model for the same input.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
2. Export the model via torch.jit.trace()
./lstm_transducer_stateless2/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp/ \
--jit-trace 1
It will generate the following 3 files inside $repo/exp
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
3. Export the model to ONNX
./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
4. Run this file
./lstm_transducer_stateless2/onnx_check.py \
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
"""
import argparse
import logging
from onnx_pretrained import OnnxModel
from icefall import is_module_available
import torch
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit-encoder-filename",
required=True,
type=str,
help="Path to the torchscript encoder model",
)
parser.add_argument(
"--jit-decoder-filename",
required=True,
type=str,
help="Path to the torchscript decoder model",
)
parser.add_argument(
"--jit-joiner-filename",
required=True,
type=str,
help="Path to the torchscript joiner model",
)
parser.add_argument(
"--onnx-encoder-filename",
required=True,
type=str,
help="Path to the ONNX encoder model",
)
parser.add_argument(
"--onnx-decoder-filename",
required=True,
type=str,
help="Path to the ONNX decoder model",
)
parser.add_argument(
"--onnx-joiner-filename",
required=True,
type=str,
help="Path to the ONNX joiner model",
)
return parser
def test_encoder(
torch_encoder_model: torch.jit.ScriptModule,
torch_encoder_proj_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
N = torch.randint(1, 100, size=(1,)).item()
T = onnx_model.segment
C = 80
x_lens = torch.tensor([T] * N)
torch_states = torch_encoder_model.get_init_states(N)
onnx_model.init_encoder_states(N)
for i in range(5):
logging.info(f"test_encoder: iter {i}")
x = torch.rand(N, T, C)
torch_encoder_out, _, torch_states = torch_encoder_model(
x, x_lens, torch_states
)
torch_encoder_out = torch_encoder_proj_model(torch_encoder_out)
onnx_encoder_out = onnx_model.run_encoder(x)
assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), (
(torch_encoder_out - onnx_encoder_out).abs().max()
)
def test_decoder(
torch_decoder_model: torch.jit.ScriptModule,
torch_decoder_proj_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
context_size = onnx_model.context_size
vocab_size = onnx_model.vocab_size
for i in range(10):
N = torch.randint(1, 100, size=(1,)).item()
logging.info(f"test_decoder: iter {i}, N={N}")
x = torch.randint(
low=1,
high=vocab_size,
size=(N, context_size),
dtype=torch.int64,
)
torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False]))
torch_decoder_out = torch_decoder_proj_model(torch_decoder_out)
torch_decoder_out = torch_decoder_out.squeeze(1)
onnx_decoder_out = onnx_model.run_decoder(x)
assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
(torch_decoder_out - onnx_decoder_out).abs().max()
)
def test_joiner(
torch_joiner_model: torch.jit.ScriptModule,
onnx_model: OnnxModel,
):
encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1]
decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1]
for i in range(10):
N = torch.randint(1, 100, size=(1,)).item()
logging.info(f"test_joiner: iter {i}, N={N}")
encoder_out = torch.rand(N, encoder_dim)
decoder_out = torch.rand(N, decoder_dim)
projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out)
projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out)
torch_joiner_out = torch_joiner_model(encoder_out, decoder_out)
onnx_joiner_out = onnx_model.run_joiner(
projected_encoder_out, projected_decoder_out
)
assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
(torch_joiner_out - onnx_joiner_out).abs().max()
)
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
torch_encoder_model = torch.jit.load(args.jit_encoder_filename)
torch_decoder_model = torch.jit.load(args.jit_decoder_filename)
torch_joiner_model = torch.jit.load(args.jit_joiner_filename)
onnx_model = OnnxModel(
encoder_model_filename=args.onnx_encoder_filename,
decoder_model_filename=args.onnx_decoder_filename,
joiner_model_filename=args.onnx_joiner_filename,
)
logging.info("Test encoder")
# When exporting the model to onnx, we have already put the encoder_proj
# inside the encoder.
test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model)
logging.info("Test decoder")
# When exporting the model to onnx, we have already put the decoder_proj
# inside the decoder.
test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model)
logging.info("Test joiner")
test_joiner(torch_joiner_model, onnx_model)
logging.info("Finished checking ONNX models")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# See https://github.com/pytorch/pytorch/issues/38342
# and https://github.com/pytorch/pytorch/issues/33354
#
# If we don't do this, the delay increases whenever there is
# a new request that changes the actual batch size.
# If you use `py-spy dump --pid <server-pid> --native`, you will
# see a lot of time is spent in re-compiling the torch script model.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
torch.manual_seed(20230207)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,428 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script loads ONNX models exported by ./export-onnx.py
and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp
ln -s exp/pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
2. Export the model to ONNX
./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file with the exported ONNX models
./lstm_transducer_stateless2/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1221-135766-0001.wav
Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
"""
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
model_type = encoder_meta["model_type"]
assert model_type == "lstm", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
T = int(encoder_meta["T"])
num_encoder_layers = int(encoder_meta["num_encoder_layers"])
d_model = int(encoder_meta["d_model"])
rnn_hidden_size = int(encoder_meta["rnn_hidden_size"])
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
logging.info(f"num_encoder_layers: {num_encoder_layers}")
logging.info(f"d_model: {d_model}")
logging.info(f"rnn_hidden_size: {rnn_hidden_size}")
N = batch_size
s0 = torch.zeros(num_encoder_layers, N, d_model)
s1 = torch.zeros(num_encoder_layers, N, rnn_hidden_size)
states = [s0.numpy(), s1.numpy()]
self.states = states
self.segment = T
self.offset = decode_chunk_len
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def _build_encoder_input_output(
self,
x: torch.Tensor,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
encoder_input = {
"x": x.numpy(),
"state0": self.states[0],
"state1": self.states[1],
}
encoder_output = ["encoder_out", "new_state0", "new_state1"]
return encoder_input, encoder_output
def _update_states(self, states: List[np.ndarray]):
self.states = states
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
T' is usually equal to ((T-3)//2-1)//2
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
out = self.encoder.run(encoder_output_names, encoder_input)
self._update_states(out[1:])
return torch.from_numpy(out[0])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
context_size: int,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
) -> List[int]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (1, T, joiner_dim)
context_size:
The context size of the decoder model.
decoder_out:
Optional. Decoder output of the previous chunk.
hyp:
Decoding results for previous chunks.
Returns:
Return the decoded results so far.
"""
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor([hyp], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
else:
assert hyp is not None, hyp
encoder_out = encoder_out.squeeze(0)
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t : t + 1]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
return hyp, decoder_out
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {args.sound_file}")
waves = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=sample_rate,
)[0]
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([waves, tail_padding])
num_processed_frames = 0
segment = model.segment
offset = model.offset
context_size = model.context_size
hyp = None
decoder_out = None
chunk = int(1 * sample_rate) # 1 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
encoder_out = model.run_encoder(frames)
hyp, decoder_out = greedy_search(
model,
encoder_out,
context_size,
decoder_out,
hyp,
)
symbol_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
text = text.replace("", " ").strip()
logging.info(args.sound_file)
logging.info(text)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -16,13 +16,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Please see
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
for usage
"""
import argparse import argparse
import logging import logging
from typing import List, Optional from typing import List, Optional
import k2
import ncnn import ncnn
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
@ -32,9 +37,9 @@ def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--bpe-model-filename", "--tokens",
type=str, type=str,
help="Path to bpe.model", help="Path to tokens.txt",
) )
parser.add_argument( parser.add_argument(
@ -251,9 +256,6 @@ def main():
model = Model(args) model = Model(args)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model_filename)
sound_file = args.sound_filename sound_file = args.sound_filename
sample_rate = 16000 sample_rate = 16000
@ -329,10 +331,16 @@ def main():
model, encoder_out.squeeze(0), decoder_out, hyp model, encoder_out.squeeze(0), decoder_out, hyp
) )
symbol_table = k2.SymbolTable.from_file(args.tokens)
context_size = 2 context_size = 2
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
text = text.replace("", " ").strip()
logging.info(sound_file) logging.info(sound_file)
logging.info(sp.decode(hyp[context_size:])) logging.info(text)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -0,0 +1 @@
../lstm_transducer_stateless2/export-onnx.py

View File

@ -0,0 +1 @@
../lstm_transducer_stateless2/onnx_check.py

View File

@ -0,0 +1 @@
../lstm_transducer_stateless2/onnx_pretrained.py

View File

@ -125,13 +125,13 @@ def test_encoder(
onnx_model: OnnxModel, onnx_model: OnnxModel,
): ):
C = 80 C = 80
for i in range(10): for i in range(3):
N = torch.randint(low=1, high=20, size=(1,)).item() N = torch.randint(low=1, high=20, size=(1,)).item()
T = torch.randint(low=50, high=100, size=(1,)).item() T = torch.randint(low=30, high=50, size=(1,)).item()
logging.info(f"test_encoder: iter {i}, N={N}, T={T}") logging.info(f"test_encoder: iter {i}, N={N}, T={T}")
x = torch.rand(N, T, C) x = torch.rand(N, T, C)
x_lens = torch.randint(low=10, high=T + 1, size=(N,)) x_lens = torch.randint(low=30, high=T + 1, size=(N,))
x_lens[0] = T x_lens[0] = T
torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)

View File

@ -0,0 +1,560 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
cd exp
ln -s pretrained-epoch-30-avg-9.pt epoch-99.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless7/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--feedforward-dims "1024,1024,2048,2048,1024"
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_proj = encoder_proj
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, encoder_out_lens
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
main()

View File

@ -41,31 +41,7 @@ Check
https://github.com/k2-fsa/sherpa https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall. for how to use the exported models outside of icefall.
(2) Export to ONNX format (2) Export `model.state_dict()`
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Please see ./onnx_pretrained.py for usage of the generated files
Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
(3) Export `model.state_dict()`
./pruned_transducer_stateless7/export.py \ ./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \ --exp-dir ./pruned_transducer_stateless7/exp \
@ -196,23 +172,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--onnx",
type=str2bool,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
@ -225,204 +184,6 @@ def get_parser():
return parser return parser
def export_encoder_model_onnx(
encoder_model: nn.Module,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T, C)
- encoder_out_lens, a tensor of shape (N,)
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 101, 80, dtype=torch.float32)
x_lens = torch.tensor([101], dtype=torch.int64)
# encoder_model = torch.jit.script(encoder_model)
# It throws the following error for the above statement
#
# RuntimeError: Exporting the operator __is_ to ONNX opset version
# 11 is not supported. Please feel free to request support or
# submit a pull request on PyTorch GitHub.
#
# I cannot find which statement causes the above error.
# torch.onnx.export() will use torch.jit.trace() internally, which
# works well for the current reworked model
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = False # Always False, so we can use torch.jit.trace() here
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
# in this case
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_proj_filename}")
@torch.no_grad() @torch.no_grad()
def main(): def main():
args = get_parser().parse_args() args = get_parser().parse_args()
@ -531,31 +292,7 @@ def main():
model.to("cpu") model.to("cpu")
model.eval() model.eval()
if params.onnx is True: if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
opset_version = 13
logging.info("Exporting to onnx format")
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
elif params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True) convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore # We won't use the forward() method of the model in C++, so just ignore
# it here. # it here.

View File

@ -1,286 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks that exported onnx models produce the same output
with the given torchscript model for the same input.
"""
import argparse
import logging
import onnxruntime as ort
import torch
from icefall import is_module_available
if not is_module_available("onnxruntime"):
raise ValueError("Please 'pip install onnxruntime' first.")
ort.set_default_logger_severity(3)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit-filename",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx-encoder-filename",
required=True,
type=str,
help="Path to the onnx encoder model",
)
parser.add_argument(
"--onnx-decoder-filename",
required=True,
type=str,
help="Path to the onnx decoder model",
)
parser.add_argument(
"--onnx-joiner-filename",
required=True,
type=str,
help="Path to the onnx joiner model",
)
parser.add_argument(
"--onnx-joiner-encoder-proj-filename",
required=True,
type=str,
help="Path to the onnx joiner encoder projection model",
)
parser.add_argument(
"--onnx-joiner-decoder-proj-filename",
required=True,
type=str,
help="Path to the onnx joiner decoder projection model",
)
return parser
def test_encoder(
model: torch.jit.ScriptModule,
encoder_session: ort.InferenceSession,
):
inputs = encoder_session.get_inputs()
outputs = encoder_session.get_outputs()
input_names = [n.name for n in inputs]
output_names = [n.name for n in outputs]
assert inputs[0].shape == ["N", "T", 80]
assert inputs[1].shape == ["N"]
for N in [1, 5]:
for T in [12, 50]:
print("N, T", N, T)
x = torch.rand(N, T, 80, dtype=torch.float32)
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
x_lens[0] = T
encoder_inputs = {
input_names[0]: x.numpy(),
input_names[1]: x_lens.numpy(),
}
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
encoder_out, encoder_out_lens = encoder_session.run(
output_names,
encoder_inputs,
)
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
encoder_out = torch.from_numpy(encoder_out)
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
(encoder_out - torch_encoder_out).abs().max(),
encoder_out.shape,
torch_encoder_out.shape,
)
def test_decoder(
model: torch.jit.ScriptModule,
decoder_session: ort.InferenceSession,
):
inputs = decoder_session.get_inputs()
outputs = decoder_session.get_outputs()
input_names = [n.name for n in inputs]
output_names = [n.name for n in outputs]
assert inputs[0].shape == ["N", 2]
for N in [1, 5, 10]:
y = torch.randint(low=1, high=500, size=(10, 2))
decoder_inputs = {input_names[0]: y.numpy()}
decoder_out = decoder_session.run(
output_names,
decoder_inputs,
)[0]
decoder_out = torch.from_numpy(decoder_out)
torch_decoder_out = model.decoder(y, need_pad=False)
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
(decoder_out - torch_decoder_out).abs().max()
)
def test_joiner(
model: torch.jit.ScriptModule,
joiner_session: ort.InferenceSession,
joiner_encoder_proj_session: ort.InferenceSession,
joiner_decoder_proj_session: ort.InferenceSession,
):
joiner_inputs = joiner_session.get_inputs()
joiner_outputs = joiner_session.get_outputs()
joiner_input_names = [n.name for n in joiner_inputs]
joiner_output_names = [n.name for n in joiner_outputs]
assert joiner_inputs[0].shape == ["N", 1, 1, 512]
assert joiner_inputs[1].shape == ["N", 1, 1, 512]
joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
encoder_proj_input_name = joiner_encoder_proj_inputs[0].name
assert joiner_encoder_proj_inputs[0].shape == ["N", 384]
joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs()
encoder_proj_output_name = joiner_encoder_proj_outputs[0].name
joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
decoder_proj_input_name = joiner_decoder_proj_inputs[0].name
assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs()
decoder_proj_output_name = joiner_decoder_proj_outputs[0].name
for N in [1, 5, 10]:
encoder_out = torch.rand(N, 384)
decoder_out = torch.rand(N, 512)
projected_encoder_out = torch.rand(N, 1, 1, 512)
projected_decoder_out = torch.rand(N, 1, 1, 512)
joiner_inputs = {
joiner_input_names[0]: projected_encoder_out.numpy(),
joiner_input_names[1]: projected_decoder_out.numpy(),
}
joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0]
joiner_out = torch.from_numpy(joiner_out)
torch_joiner_out = model.joiner(
projected_encoder_out,
projected_decoder_out,
project_input=False,
)
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
(joiner_out - torch_joiner_out).abs().max()
)
# Now test encoder_proj
joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
joiner_encoder_proj_out = joiner_encoder_proj_session.run(
[encoder_proj_output_name], joiner_encoder_proj_inputs
)[0]
joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
assert torch.allclose(
joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
# Now test decoder_proj
joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
joiner_decoder_proj_out = joiner_decoder_proj_session.run(
[decoder_proj_output_name], joiner_decoder_proj_inputs
)[0]
joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
assert torch.allclose(
joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
model = torch.jit.load(args.jit_filename)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
logging.info("Test encoder")
encoder_session = ort.InferenceSession(
args.onnx_encoder_filename,
sess_options=options,
)
test_encoder(model, encoder_session)
logging.info("Test decoder")
decoder_session = ort.InferenceSession(
args.onnx_decoder_filename,
sess_options=options,
)
test_decoder(model, decoder_session)
logging.info("Test joiner")
joiner_session = ort.InferenceSession(
args.onnx_joiner_filename,
sess_options=options,
)
joiner_encoder_proj_session = ort.InferenceSession(
args.onnx_joiner_encoder_proj_filename,
sess_options=options,
)
joiner_decoder_proj_session = ort.InferenceSession(
args.onnx_joiner_decoder_proj_filename,
sess_options=options,
)
test_joiner(
model,
joiner_session,
joiner_encoder_proj_session,
joiner_decoder_proj_session,
)
logging.info("Finished checking ONNX models")
if __name__ == "__main__":
torch.manual_seed(20220727)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless5/onnx_check.py

View File

@ -1,388 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--onnx 1
Usage of this script:
./pruned_transducer_stateless7/onnx_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless7/exp/encoder.onnx \
--decoder-model-filename ./pruned_transducer_stateless7/exp/decoder.onnx \
--joiner-model-filename ./pruned_transducer_stateless7/exp/joiner.onnx \
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_decoder_proj.onnx \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--joiner-encoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner encoder_proj onnx model. ",
)
parser.add_argument(
"--joiner-decoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner decoder_proj onnx model. ",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="Context size of the decoder model",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: ort.InferenceSession,
joiner: ort.InferenceSession,
joiner_encoder_proj: ort.InferenceSession,
joiner_decoder_proj: ort.InferenceSession,
encoder_out: np.ndarray,
encoder_out_lens: np.ndarray,
context_size: int,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
decoder:
The decoder model.
joiner:
The joiner model.
joiner_encoder_proj:
The joiner encoder projection model.
joiner_decoder_proj:
The joiner decoder projection model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
context_size:
The context size of the decoder model.
Returns:
Return the decoded results for each utterance.
"""
encoder_out = torch.from_numpy(encoder_out)
encoder_out_lens = torch.from_numpy(encoder_out_lens)
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
projected_encoder_out = joiner_encoder_proj.run(
[joiner_encoder_proj.get_outputs()[0].name],
{joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
)[0]
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input_nodes = decoder.get_inputs()
decoder_output_nodes = decoder.get_outputs()
joiner_input_nodes = joiner.get_inputs()
joiner_output_nodes = joiner.get_outputs()
decoder_input = torch.tensor(
hyps,
dtype=torch.int64,
) # (N, context_size)
decoder_out = decoder.run(
[decoder_output_nodes[0].name],
{
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = projected_encoder_out[start:end]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
projected_decoder_out = projected_decoder_out[:batch_size]
logits = joiner.run(
[joiner_output_nodes[0].name],
{
joiner_input_nodes[0].name: np.expand_dims(
np.expand_dims(current_encoder_out, axis=1), axis=1
),
joiner_input_nodes[1]
.name: projected_decoder_out.unsqueeze(1)
.unsqueeze(1)
.numpy(),
},
)[0]
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
dtype=torch.int64,
)
decoder_out = decoder.run(
[decoder_output_nodes[0].name],
{
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
encoder = ort.InferenceSession(
args.encoder_model_filename,
sess_options=session_opts,
)
decoder = ort.InferenceSession(
args.decoder_model_filename,
sess_options=session_opts,
)
joiner = ort.InferenceSession(
args.joiner_model_filename,
sess_options=session_opts,
)
joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=session_opts,
)
joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=session_opts,
)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
encoder_input_nodes = encoder.get_inputs()
encoder_out_nodes = encoder.get_outputs()
encoder_out, encoder_out_lens = encoder.run(
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
{
encoder_input_nodes[0].name: features.numpy(),
encoder_input_nodes[1].name: feature_lengths.numpy(),
},
)
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
joiner_encoder_proj=joiner_encoder_proj,
joiner_decoder_proj=joiner_decoder_proj,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_size=args.context_size,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/onnx_pretrained.py

View File

@ -44,7 +44,7 @@ from scaling import (
from torch import Tensor, nn from torch import Tensor, nn
from icefall.dist import get_rank from icefall.dist import get_rank
from icefall.utils import make_pad_mask from icefall.utils import is_jit_tracing, make_pad_mask
class Zipformer(EncoderInterface): class Zipformer(EncoderInterface):
@ -792,6 +792,7 @@ class AttentionDownsample(torch.nn.Module):
src = src.reshape(d_seq_len, ds, batch_size, in_channels) src = src.reshape(d_seq_len, ds, batch_size, in_channels)
scores = (src * self.query).sum(dim=-1, keepdim=True) scores = (src * self.query).sum(dim=-1, keepdim=True)
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04)
weights = scores.softmax(dim=1) weights = scores.softmax(dim=1)
@ -904,6 +905,13 @@ class RelPositionalEncoding(torch.nn.Module):
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct a PositionalEncoding object.""" """Construct a PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__() super(RelPositionalEncoding, self).__init__()
if is_jit_tracing():
# 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e.,
# It assumes that the maximum input won't have more than
# 10k frames.
#
# TODO(fangjun): Use torch.jit.script() for this module
max_len = 10000
self.d_model = d_model self.d_model = d_model
self.dropout = torch.nn.Dropout(dropout_rate) self.dropout = torch.nn.Dropout(dropout_rate)
self.pe = None self.pe = None

View File

@ -306,11 +306,11 @@ def export_encoder_model_onnx(
left_context_len = ",".join(map(str, left_context_len)) left_context_len = ",".join(map(str, left_context_len))
meta_data = { meta_data = {
"model_type": "streaming_zipformer", "model_type": "zipformer",
"version": "1", "version": "1",
"model_author": "k2-fsa", "model_author": "k2-fsa",
"decode_chunk_len": str(decode_chunk_len), # 32 "decode_chunk_len": str(decode_chunk_len), # 32
"pad_length": str(pad_length), # 7 "T": str(T), # 39
"num_encoder_layers": num_encoder_layers, "num_encoder_layers": num_encoder_layers,
"encoder_dims": encoder_dims, "encoder_dims": encoder_dims,
"attention_dims": attention_dims, "attention_dims": attention_dims,
@ -362,8 +362,8 @@ def export_encoder_model_onnx(
input_names=input_names, input_names=input_names,
output_names=output_names, output_names=output_names,
dynamic_axes={ dynamic_axes={
"x": {0: "N", 1: "T"}, "x": {0: "N"},
"encoder_out": {0: "N", 1: "T"}, "encoder_out": {0: "N"},
**inputs, **inputs,
**outputs, **outputs,
}, },

View File

@ -136,8 +136,11 @@ class OnnxModel:
def init_encoder_states(self, batch_size: int = 1): def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
model_type = encoder_meta["model_type"]
assert model_type == "zipformer", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"]) decode_chunk_len = int(encoder_meta["decode_chunk_len"])
pad_length = int(encoder_meta["pad_length"]) T = int(encoder_meta["T"])
num_encoder_layers = encoder_meta["num_encoder_layers"] num_encoder_layers = encoder_meta["num_encoder_layers"]
encoder_dims = encoder_meta["encoder_dims"] encoder_dims = encoder_meta["encoder_dims"]
@ -155,7 +158,7 @@ class OnnxModel:
left_context_len = to_int_list(left_context_len) left_context_len = to_int_list(left_context_len)
logging.info(f"decode_chunk_len: {decode_chunk_len}") logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"pad_length: {pad_length}") logging.info(f"T: {T}")
logging.info(f"num_encoder_layers: {num_encoder_layers}") logging.info(f"num_encoder_layers: {num_encoder_layers}")
logging.info(f"encoder_dims: {encoder_dims}") logging.info(f"encoder_dims: {encoder_dims}")
logging.info(f"attention_dims: {attention_dims}") logging.info(f"attention_dims: {attention_dims}")
@ -219,7 +222,7 @@ class OnnxModel:
self.num_encoders = num_encoders self.num_encoders = num_encoders
self.segment = decode_chunk_len + pad_length self.segment = T
self.offset = decode_chunk_len self.offset = decode_chunk_len
def init_decoder(self, decoder_model_filename: str): def init_decoder(self, decoder_model_filename: str):