mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Add ONNX support for Zipformer and ConvEmformer (#884)
This commit is contained in:
parent
af735eb75b
commit
2b995639b7
@ -1,79 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
#
|
|
||||||
set -e
|
|
||||||
|
|
||||||
log() {
|
|
||||||
# This function is from espnet
|
|
||||||
local fname=${BASH_SOURCE[1]##*/}
|
|
||||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
|
||||||
}
|
|
||||||
|
|
||||||
cd egs/librispeech/ASR
|
|
||||||
|
|
||||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
|
|
||||||
|
|
||||||
log "Downloading pre-trained model from $repo_url"
|
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|
||||||
repo=$(basename $repo_url)
|
|
||||||
pushd $repo
|
|
||||||
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
cd exp
|
|
||||||
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
|
|
||||||
popd
|
|
||||||
|
|
||||||
log "Display test files"
|
|
||||||
tree $repo/
|
|
||||||
soxi $repo/test_wavs/*.wav
|
|
||||||
ls -lh $repo/test_wavs/*.wav
|
|
||||||
|
|
||||||
log "Install ncnn and pnnx"
|
|
||||||
|
|
||||||
# We are using a modified ncnn here. Will try to merge it to the official repo
|
|
||||||
# of ncnn
|
|
||||||
git clone https://github.com/csukuangfj/ncnn
|
|
||||||
pushd ncnn
|
|
||||||
git submodule init
|
|
||||||
git submodule update python/pybind11
|
|
||||||
python3 setup.py bdist_wheel
|
|
||||||
ls -lh dist/
|
|
||||||
pip install dist/*.whl
|
|
||||||
cd tools/pnnx
|
|
||||||
mkdir build
|
|
||||||
cd build
|
|
||||||
cmake -D Python3_EXECUTABLE=/opt/hostedtoolcache/Python/3.8.14/x64/bin/python3 ..
|
|
||||||
make -j4 pnnx
|
|
||||||
|
|
||||||
./src/pnnx || echo "pass"
|
|
||||||
|
|
||||||
popd
|
|
||||||
|
|
||||||
log "Test exporting to pnnx format"
|
|
||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
|
||||||
--exp-dir $repo/exp \
|
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 99 \
|
|
||||||
--avg 1 \
|
|
||||||
--use-averaged-model 0 \
|
|
||||||
\
|
|
||||||
--num-encoder-layers 12 \
|
|
||||||
--chunk-length 32 \
|
|
||||||
--cnn-module-kernel 31 \
|
|
||||||
--left-context-length 32 \
|
|
||||||
--right-context-length 8 \
|
|
||||||
--memory-size 32
|
|
||||||
|
|
||||||
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
|
|
||||||
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
|
|
||||||
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
|
|
||||||
|
|
||||||
./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
|
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
|
||||||
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
|
|
||||||
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
|
||||||
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
|
|
||||||
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
|
|
||||||
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
|
|
||||||
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
|
|
||||||
$repo/test_wavs/1089-134686-0001.wav
|
|
@ -28,63 +28,6 @@ ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
|
|||||||
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
|
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
|
||||||
popd
|
popd
|
||||||
|
|
||||||
log "Install ncnn and pnnx"
|
|
||||||
|
|
||||||
# We are using a modified ncnn here. Will try to merge it to the official repo
|
|
||||||
# of ncnn
|
|
||||||
git clone https://github.com/csukuangfj/ncnn
|
|
||||||
pushd ncnn
|
|
||||||
git submodule init
|
|
||||||
git submodule update python/pybind11
|
|
||||||
python3 setup.py bdist_wheel
|
|
||||||
ls -lh dist/
|
|
||||||
pip install dist/*.whl
|
|
||||||
cd tools/pnnx
|
|
||||||
mkdir build
|
|
||||||
cd build
|
|
||||||
cmake ..
|
|
||||||
make -j4 pnnx
|
|
||||||
|
|
||||||
./src/pnnx || echo "pass"
|
|
||||||
|
|
||||||
popd
|
|
||||||
|
|
||||||
log "Test exporting to pnnx format"
|
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
|
||||||
--exp-dir $repo/exp \
|
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 99 \
|
|
||||||
--avg 1 \
|
|
||||||
--use-averaged-model 0 \
|
|
||||||
--pnnx 1
|
|
||||||
|
|
||||||
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
|
|
||||||
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
|
|
||||||
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
|
|
||||||
|
|
||||||
./lstm_transducer_stateless2/ncnn-decode.py \
|
|
||||||
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
|
|
||||||
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
|
||||||
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
|
|
||||||
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
|
|
||||||
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
|
|
||||||
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
|
|
||||||
$repo/test_wavs/1089-134686-0001.wav
|
|
||||||
|
|
||||||
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
|
|
||||||
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
|
|
||||||
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
|
||||||
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
|
|
||||||
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
|
|
||||||
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
|
|
||||||
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
|
|
||||||
$repo/test_wavs/1089-134686-0001.wav
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
log "Test exporting with torch.jit.trace()"
|
log "Test exporting with torch.jit.trace()"
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
@ -106,47 +49,6 @@ log "Decode with models exported by torch.jit.trace()"
|
|||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
log "Test exporting to ONNX"
|
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
|
||||||
--exp-dir $repo/exp \
|
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 99 \
|
|
||||||
--avg 1 \
|
|
||||||
--use-averaged-model 0 \
|
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
log "Decode with ONNX models "
|
|
||||||
|
|
||||||
./lstm_transducer_stateless2/streaming-onnx-decode.py \
|
|
||||||
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--encoder-model-filename $repo//exp/encoder.onnx \
|
|
||||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
|
||||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
|
||||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
|
||||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
|
||||||
$repo/test_wavs/1089-134686-0001.wav
|
|
||||||
|
|
||||||
./lstm_transducer_stateless2/streaming-onnx-decode.py \
|
|
||||||
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--encoder-model-filename $repo//exp/encoder.onnx \
|
|
||||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
|
||||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
|
||||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
|
||||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
|
||||||
$repo/test_wavs/1221-135766-0001.wav
|
|
||||||
|
|
||||||
./lstm_transducer_stateless2/streaming-onnx-decode.py \
|
|
||||||
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--encoder-model-filename $repo//exp/encoder.onnx \
|
|
||||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
|
||||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
|
||||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
|
||||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for sym in 1 2 3; do
|
for sym in 1 2 3; do
|
||||||
log "Greedy search with --max-sym-per-frame $sym"
|
log "Greedy search with --max-sym-per-frame $sym"
|
||||||
|
|
||||||
|
@ -30,15 +30,6 @@ ln -s pretrained.pt epoch-99.pt
|
|||||||
ls -lh *.pt
|
ls -lh *.pt
|
||||||
popd
|
popd
|
||||||
|
|
||||||
log "Test exporting to ONNX format"
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
|
||||||
--exp-dir $repo/exp \
|
|
||||||
--use-averaged-model false \
|
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 99 \
|
|
||||||
--avg 1 \
|
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
log "Export to torchscript model"
|
log "Export to torchscript model"
|
||||||
./pruned_transducer_stateless7/export.py \
|
./pruned_transducer_stateless7/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
@ -50,27 +41,6 @@ log "Export to torchscript model"
|
|||||||
|
|
||||||
ls -lh $repo/exp/*.pt
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
log "Decode with ONNX models"
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/onnx_check.py \
|
|
||||||
--jit-filename $repo/exp/cpu_jit.pt \
|
|
||||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
|
||||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
|
||||||
--onnx-joiner-filename $repo/exp/joiner.onnx \
|
|
||||||
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
|
|
||||||
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/onnx_pretrained.py \
|
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--encoder-model-filename $repo/exp/encoder.onnx \
|
|
||||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
|
||||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
|
||||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
|
||||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
|
||||||
|
|
||||||
log "Decode with models exported by torch.jit.script()"
|
log "Decode with models exported by torch.jit.script()"
|
||||||
|
|
||||||
./pruned_transducer_stateless7/jit_pretrained.py \
|
./pruned_transducer_stateless7/jit_pretrained.py \
|
||||||
|
@ -34,16 +34,6 @@ ln -s pretrained.pt epoch-99.pt
|
|||||||
ls -lh *.pt
|
ls -lh *.pt
|
||||||
popd
|
popd
|
||||||
|
|
||||||
log "Test exporting to ONNX format"
|
|
||||||
./pruned_transducer_stateless7_streaming/export.py \
|
|
||||||
--exp-dir $repo/exp \
|
|
||||||
--use-averaged-model false \
|
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 99 \
|
|
||||||
--avg 1 \
|
|
||||||
--fp16 \
|
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
log "Export to torchscript model"
|
log "Export to torchscript model"
|
||||||
./pruned_transducer_stateless7_streaming/export.py \
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
|
133
.github/scripts/test-ncnn-export.sh
vendored
Executable file
133
.github/scripts/test-ncnn-export.sh
vendored
Executable file
@ -0,0 +1,133 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
log "Install ncnn and pnnx"
|
||||||
|
|
||||||
|
# We are using a modified ncnn here. Will try to merge it to the official repo
|
||||||
|
# of ncnn
|
||||||
|
git clone https://github.com/csukuangfj/ncnn
|
||||||
|
pushd ncnn
|
||||||
|
git submodule init
|
||||||
|
git submodule update python/pybind11
|
||||||
|
python3 setup.py bdist_wheel
|
||||||
|
ls -lh dist/
|
||||||
|
pip install dist/*.whl
|
||||||
|
cd tools/pnnx
|
||||||
|
mkdir build
|
||||||
|
cd build
|
||||||
|
|
||||||
|
echo "which python3"
|
||||||
|
|
||||||
|
which python3
|
||||||
|
#/opt/hostedtoolcache/Python/3.8.16/x64/bin/python3
|
||||||
|
|
||||||
|
cmake -D Python3_EXECUTABLE=$(which python3) ..
|
||||||
|
make -j4 pnnx
|
||||||
|
|
||||||
|
./src/pnnx || echo "pass"
|
||||||
|
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "=========================================================================="
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Export via torch.jit.trace()"
|
||||||
|
|
||||||
|
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
\
|
||||||
|
--num-encoder-layers 12 \
|
||||||
|
--chunk-length 32 \
|
||||||
|
--cnn-module-kernel 31 \
|
||||||
|
--left-context-length 32 \
|
||||||
|
--right-context-length 8 \
|
||||||
|
--memory-size 32
|
||||||
|
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
|
||||||
|
|
||||||
|
python3 ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
|
||||||
|
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
log "--------------------------------------------------------------------------"
|
||||||
|
|
||||||
|
log "=========================================================================="
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Export via torch.jit.trace()"
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--use-averaged-model 0
|
||||||
|
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
|
||||||
|
./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
|
||||||
|
|
||||||
|
python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
|
||||||
|
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
python3 ./lstm_transducer_stateless2/ncnn-decode.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
|
||||||
|
--decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
|
||||||
|
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
|
||||||
|
--joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
log "--------------------------------------------------------------------------"
|
157
.github/scripts/test-onnx-export.sh
vendored
157
.github/scripts/test-onnx-export.sh
vendored
@ -10,6 +10,8 @@ log() {
|
|||||||
|
|
||||||
cd egs/librispeech/ASR
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
log "=========================================================================="
|
log "=========================================================================="
|
||||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
log "Downloading pre-trained model from $repo_url"
|
log "Downloading pre-trained model from $repo_url"
|
||||||
@ -192,3 +194,158 @@ log "Run onnx_pretrained.py"
|
|||||||
|
|
||||||
rm -rf $repo
|
rm -rf $repo
|
||||||
log "--------------------------------------------------------------------------"
|
log "--------------------------------------------------------------------------"
|
||||||
|
|
||||||
|
log "=========================================================================="
|
||||||
|
repo_url=
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
log "--------------------------------------------------------------------------"
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Export via torch.jit.script()"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/export.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/export-onnx.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024"
|
||||||
|
|
||||||
|
ls -lh $repo/exp
|
||||||
|
|
||||||
|
log "Run onnx_check.py"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/onnx_check.py \
|
||||||
|
--jit-filename $repo/exp/cpu_jit.pt \
|
||||||
|
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
log "Run onnx_pretrained.py"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
log "=========================================================================="
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
|
./conv_emformer_transducer_stateless2/export-onnx.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--num-encoder-layers 12 \
|
||||||
|
--chunk-length 32 \
|
||||||
|
--cnn-module-kernel 31 \
|
||||||
|
--left-context-length 32 \
|
||||||
|
--right-context-length 8 \
|
||||||
|
--memory-size 32
|
||||||
|
|
||||||
|
log "Run onnx_pretrained.py"
|
||||||
|
|
||||||
|
./conv_emformer_transducer_stateless2/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
log "--------------------------------------------------------------------------"
|
||||||
|
|
||||||
|
log "=========================================================================="
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Export via torch.jit.trace()"
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/export.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp/ \
|
||||||
|
--jit-trace 1
|
||||||
|
|
||||||
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/export-onnx.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp
|
||||||
|
|
||||||
|
ls -lh $repo/exp
|
||||||
|
|
||||||
|
log "Run onnx_check.py"
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/onnx_check.py \
|
||||||
|
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
|
||||||
|
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
|
||||||
|
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
|
||||||
|
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
log "Run onnx_pretrained.py"
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
log "--------------------------------------------------------------------------"
|
||||||
|
@ -39,7 +39,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_2022_11_11_zipformer:
|
run_librispeech_2022_11_11_zipformer:
|
||||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
@ -39,7 +39,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_2022_12_29_zipformer_streaming:
|
run_librispeech_2022_12_29_zipformer_streaming:
|
||||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
@ -22,7 +22,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_lstm_transducer_stateless2_2022_09_03:
|
run_librispeech_lstm_transducer_stateless2_2022_09_03:
|
||||||
if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
name: run-librispeech-conv-emformer-transducer-stateless2-2022-12-05
|
name: test-ncnn-export
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -16,15 +16,18 @@ on:
|
|||||||
# nightly build at 15:50 UTC time every day
|
# nightly build at 15:50 UTC time every day
|
||||||
- cron: "50 15 * * *"
|
- cron: "50 15 * * *"
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: test_ncnn_export-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_conv_emformer_transducer_stateless2_2022_12_05:
|
test_ncnn_export:
|
||||||
if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@ -41,7 +44,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: |
|
run: |
|
||||||
grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install
|
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||||
pip uninstall -y protobuf
|
pip uninstall -y protobuf
|
||||||
pip install --no-binary protobuf protobuf
|
pip install --no-binary protobuf protobuf
|
||||||
|
|
||||||
@ -59,19 +62,14 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
.github/scripts/install-kaldifeat.sh
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
- name: Inference with pre-trained model
|
- name: Test ncnn export
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
run: |
|
run: |
|
||||||
mkdir -p egs/librispeech/ASR/data
|
|
||||||
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
|
|
||||||
ls -lh egs/librispeech/ASR/data/*
|
|
||||||
|
|
||||||
sudo apt-get -qq install git-lfs tree sox
|
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
|
.github/scripts/test-ncnn-export.sh
|
@ -1,69 +1,78 @@
|
|||||||
Export to ONNX
|
Export to ONNX
|
||||||
==============
|
==============
|
||||||
|
|
||||||
In this section, we describe how to export models to ONNX.
|
In this section, we describe how to export the following models to ONNX.
|
||||||
|
|
||||||
|
In each recipe, there is a file called ``export-onnx.py``, which is used
|
||||||
|
to export trained models to ONNX.
|
||||||
|
|
||||||
|
There is also a file named ``onnx_pretrained.py``, which you can use
|
||||||
|
the exported ONNX model in Python to decode sound files.
|
||||||
|
|
||||||
|
Example
|
||||||
|
=======
|
||||||
|
|
||||||
|
In the following, we demonstrate how to export a streaming Zipformer pre-trained
|
||||||
|
model from `<python3 ./python-api-examples/speech-recognition-from-microphone.py>`_
|
||||||
|
to ONNX.
|
||||||
|
|
||||||
|
Download the pre-trained model
|
||||||
|
------------------------------
|
||||||
|
|
||||||
.. hint::
|
.. hint::
|
||||||
|
|
||||||
Only non-streaming conformer transducer models are tested.
|
We assume you have installed `git-lfs`_.
|
||||||
|
|
||||||
|
|
||||||
When to use it
|
|
||||||
--------------
|
|
||||||
|
|
||||||
It you want to use an inference framework that supports ONNX
|
|
||||||
to run the pretrained model.
|
|
||||||
|
|
||||||
|
|
||||||
How to export
|
|
||||||
-------------
|
|
||||||
|
|
||||||
We use
|
|
||||||
`<https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless3>`_
|
|
||||||
as an example in the following.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
|
||||||
cd egs/librispeech/ASR
|
cd egs/librispeech/ASR
|
||||||
epoch=14
|
|
||||||
avg=2
|
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
repo=$(basename $repo_url)
|
||||||
--epoch $epoch \
|
|
||||||
--avg $avg \
|
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
It will generate the following files inside ``pruned_transducer_stateless3/exp``:
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
- ``encoder.onnx``
|
Export the model to ONNX
|
||||||
- ``decoder.onnx``
|
------------------------
|
||||||
- ``joiner.onnx``
|
|
||||||
- ``joiner_encoder_proj.onnx``
|
|
||||||
- ``joiner_decoder_proj.onnx``
|
|
||||||
|
|
||||||
You can use ``./pruned_transducer_stateless3/exp/onnx_pretrained.py`` to decode
|
|
||||||
waves with the generated files:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
./pruned_transducer_stateless3/onnx_pretrained.py \
|
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
|
--use-averaged-model 0 \
|
||||||
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
|
--epoch 99 \
|
||||||
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
|
--avg 1 \
|
||||||
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \
|
--decode-chunk-len 32 \
|
||||||
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \
|
--exp-dir $repo/exp/
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav \
|
|
||||||
/path/to/baz.wav
|
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
How to use the exported model
|
``export-onnx.py`` from different recipes has different options.
|
||||||
-----------------------------
|
|
||||||
|
|
||||||
We also provide `<https://github.com/k2-fsa/sherpa-onnx>`_
|
In the above example, ``--decode-chunk-len`` is specific for the
|
||||||
performing speech recognition using `onnxruntime <https://github.com/microsoft/onnxruntime>`_
|
streaming Zipformer. Other models won't have such an option.
|
||||||
with exported models.
|
|
||||||
It has been tested on Linux, macOS, and Windows.
|
It will generate the following 3 files in ``$repo/exp``
|
||||||
|
|
||||||
|
- ``encoder-epoch-99-avg-1.onnx``
|
||||||
|
- ``decoder-epoch-99-avg-1.onnx``
|
||||||
|
- ``joiner-epoch-99-avg-1.onnx``
|
||||||
|
|
||||||
|
Decode sound files with exported ONNX models
|
||||||
|
--------------------------------------------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
@ -580,12 +580,11 @@ for ``pnnx``:
|
|||||||
iter=468000
|
iter=468000
|
||||||
avg=16
|
avg=16
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
--iter $iter \
|
--iter $iter \
|
||||||
--avg $avg \
|
--avg $avg
|
||||||
--pnnx 1
|
|
||||||
|
|
||||||
It will generate 3 files:
|
It will generate 3 files:
|
||||||
|
|
||||||
@ -615,7 +614,7 @@ To use the above generated files, run:
|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
./lstm_transducer_stateless2/ncnn-decode.py \
|
./lstm_transducer_stateless2/ncnn-decode.py \
|
||||||
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
|
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
|
||||||
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
||||||
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
|
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
|
||||||
@ -627,7 +626,7 @@ To use the above generated files, run:
|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
|
./lstm_transducer_stateless2/streaming-ncnn-decode.py \
|
||||||
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
|
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
|
||||||
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
|
||||||
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
|
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
|
||||||
@ -657,6 +656,3 @@ by visiting the following links:
|
|||||||
|
|
||||||
You can find more usages of the pretrained models in
|
You can find more usages of the pretrained models in
|
||||||
`<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_
|
`<https://k2-fsa.github.io/sherpa/python/streaming_asr/lstm/index.html>`_
|
||||||
|
|
||||||
Export ConvEmformer transducer models for ncnn
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
|
@ -169,9 +169,11 @@ class ConvolutionModule(nn.Module):
|
|||||||
channels: int,
|
channels: int,
|
||||||
kernel_size: int,
|
kernel_size: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
|
is_pnnx: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Construct an ConvolutionModule object."""
|
"""Construct an ConvolutionModule object."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.is_pnnx = is_pnnx
|
||||||
# kernerl_size should be an odd number for 'SAME' padding
|
# kernerl_size should be an odd number for 'SAME' padding
|
||||||
assert (kernel_size - 1) % 2 == 0, kernel_size
|
assert (kernel_size - 1) % 2 == 0, kernel_size
|
||||||
|
|
||||||
@ -383,8 +385,10 @@ class ConvolutionModule(nn.Module):
|
|||||||
- output right_context of shape (R, B, D).
|
- output right_context of shape (R, B, D).
|
||||||
- updated cache tensor of shape (B, D, cache_size).
|
- updated cache tensor of shape (B, D, cache_size).
|
||||||
"""
|
"""
|
||||||
# U, B, D = utterance.size()
|
if self.is_pnnx is False:
|
||||||
# R, _, _ = right_context.size()
|
U, B, D = utterance.size()
|
||||||
|
R, _, _ = right_context.size()
|
||||||
|
else:
|
||||||
U = self.chunk_length
|
U = self.chunk_length
|
||||||
B = 1
|
B = 1
|
||||||
D = self.channels
|
D = self.channels
|
||||||
@ -448,8 +452,10 @@ class EmformerAttention(nn.Module):
|
|||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
tanh_on_mem: bool = False,
|
tanh_on_mem: bool = False,
|
||||||
negative_inf: float = -1e8,
|
negative_inf: float = -1e8,
|
||||||
|
is_pnnx: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.is_pnnx = is_pnnx
|
||||||
|
|
||||||
if embed_dim % nhead != 0:
|
if embed_dim % nhead != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -539,10 +545,11 @@ class EmformerAttention(nn.Module):
|
|||||||
left_context_val: Optional[torch.Tensor] = None,
|
left_context_val: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""Underlying chunk-wise attention implementation."""
|
"""Underlying chunk-wise attention implementation."""
|
||||||
# U, B, _ = utterance.size()
|
if self.is_pnnx is False:
|
||||||
# R = right_context.size(0)
|
U, B, _ = utterance.size()
|
||||||
# M = memory.size(0)
|
R = right_context.size(0)
|
||||||
|
M = memory.size(0)
|
||||||
|
else:
|
||||||
U = self.chunk_length
|
U = self.chunk_length
|
||||||
B = 1
|
B = 1
|
||||||
R = self.right_context_length
|
R = self.right_context_length
|
||||||
@ -570,21 +577,29 @@ class EmformerAttention(nn.Module):
|
|||||||
|
|
||||||
# KV = key.size(0)
|
# KV = key.size(0)
|
||||||
|
|
||||||
|
if self.is_pnnx is True:
|
||||||
reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2)
|
reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2)
|
||||||
reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute(
|
reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute(
|
||||||
1, 0, 2
|
1, 0, 2
|
||||||
)
|
)
|
||||||
reshaped_value = value.view(M + R + U + L, self.nhead, self.head_dim).permute(
|
reshaped_value = value.view(
|
||||||
1, 0, 2
|
M + R + U + L, self.nhead, self.head_dim
|
||||||
)
|
).permute(1, 0, 2)
|
||||||
|
else:
|
||||||
# reshaped_query, reshaped_key, reshaped_value = [
|
reshaped_query, reshaped_key, reshaped_value = [
|
||||||
# tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1)
|
tensor.contiguous()
|
||||||
# for tensor in [query, key, value]
|
.view(-1, B * self.nhead, self.head_dim)
|
||||||
# ] # (B * nhead, Q or KV, head_dim)
|
.transpose(0, 1)
|
||||||
|
for tensor in [query, key, value]
|
||||||
|
] # (B * nhead, Q or KV, head_dim)
|
||||||
|
if self.is_pnnx is True:
|
||||||
attention_weights = torch.bmm(
|
attention_weights = torch.bmm(
|
||||||
reshaped_query * scaling, reshaped_key.permute(0, 2, 1)
|
reshaped_query * scaling, reshaped_key.permute(0, 2, 1)
|
||||||
) # (B * nhead, Q, KV)
|
) # (B * nhead, Q, KV)
|
||||||
|
else:
|
||||||
|
attention_weights = torch.bmm(
|
||||||
|
reshaped_query * scaling, reshaped_key.transpose(1, 2)
|
||||||
|
) # (B * nhead, Q, KV)
|
||||||
|
|
||||||
# compute attention probabilities
|
# compute attention probabilities
|
||||||
if False:
|
if False:
|
||||||
@ -597,10 +612,15 @@ class EmformerAttention(nn.Module):
|
|||||||
# compute attention outputs
|
# compute attention outputs
|
||||||
attention = torch.bmm(attention_probs, reshaped_value)
|
attention = torch.bmm(attention_probs, reshaped_value)
|
||||||
assert attention.shape == (B * self.nhead, Q, self.head_dim)
|
assert attention.shape == (B * self.nhead, Q, self.head_dim)
|
||||||
|
if self.is_pnnx is True:
|
||||||
attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim)
|
attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim)
|
||||||
# TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim)
|
# TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim)
|
||||||
# We have to change InnerProduct in ncnn to ignore the extra dim below
|
# We have to change InnerProduct in ncnn to ignore the extra dim below
|
||||||
attention = attention.unsqueeze(1)
|
attention = attention.unsqueeze(1)
|
||||||
|
else:
|
||||||
|
attention = (
|
||||||
|
attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
|
||||||
|
)
|
||||||
|
|
||||||
# apply output projection
|
# apply output projection
|
||||||
output_right_context_utterance = self.out_proj(attention)
|
output_right_context_utterance = self.out_proj(attention)
|
||||||
@ -733,11 +753,12 @@ class EmformerAttention(nn.Module):
|
|||||||
- attention value of left context and utterance, which would be
|
- attention value of left context and utterance, which would be
|
||||||
cached for next computation, with shape (L + U, B, D).
|
cached for next computation, with shape (L + U, B, D).
|
||||||
"""
|
"""
|
||||||
# U = utterance.size(0)
|
if self.is_pnnx is False:
|
||||||
# R = right_context.size(0)
|
U = utterance.size(0)
|
||||||
# L = left_context_key.size(0)
|
R = right_context.size(0)
|
||||||
# M = memory.size(0)
|
L = left_context_key.size(0)
|
||||||
|
M = memory.size(0)
|
||||||
|
else:
|
||||||
U = self.chunk_length
|
U = self.chunk_length
|
||||||
R = self.right_context_length
|
R = self.right_context_length
|
||||||
L = self.left_context_length
|
L = self.left_context_length
|
||||||
@ -811,6 +832,7 @@ class EmformerEncoderLayer(nn.Module):
|
|||||||
memory_size: int = 0,
|
memory_size: int = 0,
|
||||||
tanh_on_mem: bool = False,
|
tanh_on_mem: bool = False,
|
||||||
negative_inf: float = -1e8,
|
negative_inf: float = -1e8,
|
||||||
|
is_pnnx: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -824,6 +846,7 @@ class EmformerEncoderLayer(nn.Module):
|
|||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
tanh_on_mem=tanh_on_mem,
|
tanh_on_mem=tanh_on_mem,
|
||||||
negative_inf=negative_inf,
|
negative_inf=negative_inf,
|
||||||
|
is_pnnx=is_pnnx,
|
||||||
)
|
)
|
||||||
self.summary_op = nn.AvgPool1d(
|
self.summary_op = nn.AvgPool1d(
|
||||||
kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
|
kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
|
||||||
@ -850,6 +873,7 @@ class EmformerEncoderLayer(nn.Module):
|
|||||||
right_context_length,
|
right_context_length,
|
||||||
d_model,
|
d_model,
|
||||||
cnn_module_kernel,
|
cnn_module_kernel,
|
||||||
|
is_pnnx=is_pnnx,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm_final = BasicNorm(d_model)
|
self.norm_final = BasicNorm(d_model)
|
||||||
@ -1204,6 +1228,7 @@ class EmformerEncoder(nn.Module):
|
|||||||
memory_size: int = 0,
|
memory_size: int = 0,
|
||||||
tanh_on_mem: bool = False,
|
tanh_on_mem: bool = False,
|
||||||
negative_inf: float = -1e8,
|
negative_inf: float = -1e8,
|
||||||
|
is_pnnx: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -1229,6 +1254,7 @@ class EmformerEncoder(nn.Module):
|
|||||||
memory_size=memory_size,
|
memory_size=memory_size,
|
||||||
tanh_on_mem=tanh_on_mem,
|
tanh_on_mem=tanh_on_mem,
|
||||||
negative_inf=negative_inf,
|
negative_inf=negative_inf,
|
||||||
|
is_pnnx=is_pnnx,
|
||||||
)
|
)
|
||||||
for layer_idx in range(num_encoder_layers)
|
for layer_idx in range(num_encoder_layers)
|
||||||
]
|
]
|
||||||
@ -1561,6 +1587,20 @@ class Emformer(EncoderInterface):
|
|||||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx)
|
self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx)
|
||||||
self.is_pnnx = is_pnnx
|
self.is_pnnx = is_pnnx
|
||||||
|
|
||||||
|
self.num_encoder_layers = num_encoder_layers
|
||||||
|
self.memory_size = memory_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.cnn_module_kernel = cnn_module_kernel
|
||||||
|
self.left_context_length = left_context_length // subsampling_factor
|
||||||
|
self.right_context_length = right_context_length
|
||||||
|
self.subsampling_factor = subsampling_factor
|
||||||
|
|
||||||
|
assert subsampling_factor == 4, subsampling_factor
|
||||||
|
pad_length = right_context_length + 2 * 4 + 3
|
||||||
|
|
||||||
|
# before subsampling
|
||||||
|
self.T = self.chunk_length + pad_length
|
||||||
|
|
||||||
self.encoder = EmformerEncoder(
|
self.encoder = EmformerEncoder(
|
||||||
chunk_length=chunk_length // subsampling_factor,
|
chunk_length=chunk_length // subsampling_factor,
|
||||||
d_model=d_model,
|
d_model=d_model,
|
||||||
@ -1575,6 +1615,7 @@ class Emformer(EncoderInterface):
|
|||||||
memory_size=memory_size,
|
memory_size=memory_size,
|
||||||
tanh_on_mem=tanh_on_mem,
|
tanh_on_mem=tanh_on_mem,
|
||||||
negative_inf=negative_inf,
|
negative_inf=negative_inf,
|
||||||
|
is_pnnx=is_pnnx,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
@ -1691,7 +1732,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
layer1_channels: int = 8,
|
layer1_channels: int = 8,
|
||||||
layer2_channels: int = 32,
|
layer2_channels: int = 32,
|
||||||
layer3_channels: int = 128,
|
layer3_channels: int = 128,
|
||||||
is_pnnx: bool = False,
|
is_pnnx: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -1767,7 +1808,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
|
|
||||||
if torch.jit.is_tracing() and self.is_pnnx:
|
if torch.jit.is_tracing() and self.is_pnnx is True:
|
||||||
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
|
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
|
||||||
x = self.out(x)
|
x = self.out(x)
|
||||||
else:
|
else:
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Please see
|
||||||
|
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
|
||||||
|
for more details about how to use this file.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
||||||
@ -44,7 +48,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -96,14 +100,6 @@ def get_parser():
|
|||||||
help="Path to the BPE model",
|
help="Path to the BPE model",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--jit",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""True to save a model after applying torch.jit.script.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -217,6 +213,8 @@ def main():
|
|||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
|
||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
@ -330,5 +328,4 @@ def main():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
main()
|
main()
|
||||||
|
644
egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
Executable file
644
egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
Executable file
@ -0,0 +1,644 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a transducer model from PyTorch to ONNX.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./conv_emformer_transducer_stateless2/export-onnx.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--num-encoder-layers 12 \
|
||||||
|
--chunk-length 32 \
|
||||||
|
--cnn-module-kernel 31 \
|
||||||
|
--left-context-length 32 \
|
||||||
|
--right-context-length 8 \
|
||||||
|
--memory-size 32
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
See ./onnx_pretrained.py for how to
|
||||||
|
use the exported ONNX models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from decoder import Decoder
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||||
|
from emformer import Emformer
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless5/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = value
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxEncoder(nn.Module):
|
||||||
|
"""A wrapper for Emformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, encoder: Emformer, encoder_proj: nn.Linear):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
A Emformer encoder.
|
||||||
|
encoder_proj:
|
||||||
|
The projection layer for encoder from the joiner.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Please see the help information of Emformer.forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||||
|
"""
|
||||||
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
|
||||||
|
|
||||||
|
encoder_out = self.encoder_proj(encoder_out)
|
||||||
|
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||||
|
|
||||||
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxDecoder(nn.Module):
|
||||||
|
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder
|
||||||
|
self.decoder_proj = decoder_proj
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, context_size).
|
||||||
|
Returns
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||||
|
decoder_output = decoder_output.squeeze(1)
|
||||||
|
output = self.decoder_proj(decoder_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxJoiner(nn.Module):
|
||||||
|
"""A wrapper for the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, output_linear: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.output_linear = output_linear
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
logit = encoder_out + decoder_out
|
||||||
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
return logit
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: OnnxEncoder,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model to ONNX format.
|
||||||
|
The exported model has the following inputs:
|
||||||
|
|
||||||
|
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||||
|
- a list of states (each layers has 4 states)
|
||||||
|
|
||||||
|
|
||||||
|
and it has two outputs:
|
||||||
|
|
||||||
|
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||||
|
- a list of states (each layers has 4 states)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
num_encoder_layers = encoder_model.encoder.num_encoder_layers
|
||||||
|
memory_size = encoder_model.encoder.memory_size
|
||||||
|
cnn_module_kernel = encoder_model.encoder.cnn_module_kernel
|
||||||
|
chunk_length = encoder_model.encoder.chunk_length
|
||||||
|
right_context_length = encoder_model.encoder.right_context_length
|
||||||
|
encoder_dim = encoder_model.encoder.d_model
|
||||||
|
left_context_length = encoder_model.encoder.left_context_length
|
||||||
|
|
||||||
|
T = encoder_model.encoder.T
|
||||||
|
|
||||||
|
logging.info(f"num_encoder_layers={num_encoder_layers}")
|
||||||
|
logging.info(f"memory_size={memory_size}")
|
||||||
|
logging.info(f"cnn_module_kernel={cnn_module_kernel}")
|
||||||
|
logging.info(f"chunk_length={chunk_length}")
|
||||||
|
logging.info(f"right_context_length={right_context_length}")
|
||||||
|
logging.info(f"encoder_dim={encoder_dim}")
|
||||||
|
logging.info(f"left_context_length={left_context_length} (after subsampling)")
|
||||||
|
logging.info(f"T={T}")
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "conv-emformer",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"decode_chunk_len": str(chunk_length), # 32
|
||||||
|
"T": str(T), # 32
|
||||||
|
"num_encoder_layers": str(num_encoder_layers),
|
||||||
|
"memory_size": str(memory_size),
|
||||||
|
"cnn_module_kernel": str(cnn_module_kernel),
|
||||||
|
"right_context_length": str(right_context_length),
|
||||||
|
"left_context_length": str(left_context_length),
|
||||||
|
"encoder_dim": str(encoder_dim),
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||||
|
states = encoder_model.encoder.init_states()
|
||||||
|
|
||||||
|
# Each layer has 4 states
|
||||||
|
assert len(states) == num_encoder_layers * 4, (len(states), num_encoder_layers)
|
||||||
|
# layer 0:
|
||||||
|
# state0: (memory_size, 1, encoder_dim)
|
||||||
|
# state1: (left_context_length, 1, encoder_dim)
|
||||||
|
# state2: (left_context_length, 1, encoder_dim)
|
||||||
|
# state3: (1, encoder_dim, cnn_module_kernel-1)
|
||||||
|
|
||||||
|
inputs = {}
|
||||||
|
input_names = ["x"]
|
||||||
|
|
||||||
|
outputs = {}
|
||||||
|
output_names = ["encoder_out"]
|
||||||
|
|
||||||
|
def build_inputs_outputs(s, name):
|
||||||
|
assert len(s) == 4, len(s)
|
||||||
|
logging.info(f"{name}_0.shape: {s[0].shape}")
|
||||||
|
input_names.append(f"{name}_0")
|
||||||
|
inputs[f"{name}_0"] = {1: "N"}
|
||||||
|
output_names.append(f"new_{name}_0")
|
||||||
|
|
||||||
|
logging.info(f"{name}_1.shape: {s[1].shape}")
|
||||||
|
input_names.append(f"{name}_1")
|
||||||
|
inputs[f"{name}_1"] = {1: "N"}
|
||||||
|
output_names.append(f"new_{name}_1")
|
||||||
|
|
||||||
|
logging.info(f"{name}_2.shape: {s[2].shape}")
|
||||||
|
input_names.append(f"{name}_2")
|
||||||
|
inputs[f"{name}_2"] = {1: "N"}
|
||||||
|
output_names.append(f"new_{name}_2")
|
||||||
|
|
||||||
|
logging.info(f"{name}_3.shape: {s[3].shape}")
|
||||||
|
input_names.append(f"{name}_3")
|
||||||
|
inputs[f"{name}_3"] = {0: "N"}
|
||||||
|
output_names.append(f"new_{name}_3")
|
||||||
|
|
||||||
|
for i in range(num_encoder_layers):
|
||||||
|
base_name = f"layer{i}"
|
||||||
|
s = states[i * 4 : (i + 1) * 4]
|
||||||
|
build_inputs_outputs(s, base_name)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, states),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=input_names,
|
||||||
|
output_names=output_names,
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N"},
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
**inputs,
|
||||||
|
**outputs,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx(
|
||||||
|
decoder_model: OnnxDecoder,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
context_size = decoder_model.decoder.context_size
|
||||||
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
|
|
||||||
|
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
y,
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"context_size": str(context_size),
|
||||||
|
"vocab_size": str(vocab_size),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"joiner_dim": str(joiner_dim),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
params.is_pnnx = False
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
|
encoder = OnnxEncoder(
|
||||||
|
encoder=model.encoder,
|
||||||
|
encoder_proj=model.joiner.encoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = OnnxDecoder(
|
||||||
|
decoder=model.decoder,
|
||||||
|
decoder_proj=model.joiner.decoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
suffix = f"iter-{params.iter}"
|
||||||
|
else:
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
suffix += f"-avg-{params.avg}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
main()
|
@ -64,6 +64,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -258,6 +259,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
# it here.
|
# it here.
|
||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
456
egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py
Executable file
456
egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py
Executable file
@ -0,0 +1,456 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script loads ONNX models exported by ./export-onnx.py
|
||||||
|
and uses them to decode waves.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./conv_emformer_transducer_stateless2/export-onnx.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--num-encoder-layers 12 \
|
||||||
|
--chunk-length 32 \
|
||||||
|
--cnn-module-kernel 31 \
|
||||||
|
--left-context-length 32 \
|
||||||
|
--right-context-length 8 \
|
||||||
|
--memory-size 32
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
3. Run this file with the exported ONNX models
|
||||||
|
|
||||||
|
./conv_emformer_transducer_stateless2/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
Note: Even though this script only supports decoding a single file,
|
||||||
|
the exported ONNX models do support batch processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the decoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the joiner onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
help="""Path to tokens.txt.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_file",
|
||||||
|
type=str,
|
||||||
|
help="The input sound file to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder_model_filename: str,
|
||||||
|
decoder_model_filename: str,
|
||||||
|
joiner_model_filename: str,
|
||||||
|
):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
self.session_opts = session_opts
|
||||||
|
|
||||||
|
self.init_encoder(encoder_model_filename)
|
||||||
|
self.init_decoder(decoder_model_filename)
|
||||||
|
self.init_joiner(joiner_model_filename)
|
||||||
|
|
||||||
|
def init_encoder(self, encoder_model_filename: str):
|
||||||
|
self.encoder = ort.InferenceSession(
|
||||||
|
encoder_model_filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
)
|
||||||
|
self.init_encoder_states()
|
||||||
|
|
||||||
|
def init_encoder_states(self, batch_size: int = 1):
|
||||||
|
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||||
|
|
||||||
|
model_type = encoder_meta["model_type"]
|
||||||
|
assert model_type == "conv-emformer", model_type
|
||||||
|
|
||||||
|
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||||
|
T = int(encoder_meta["T"])
|
||||||
|
|
||||||
|
num_encoder_layers = int(encoder_meta["num_encoder_layers"])
|
||||||
|
memory_size = int(encoder_meta["memory_size"])
|
||||||
|
cnn_module_kernel = int(encoder_meta["cnn_module_kernel"])
|
||||||
|
right_context_length = int(encoder_meta["right_context_length"])
|
||||||
|
left_context_length = int(encoder_meta["left_context_length"])
|
||||||
|
encoder_dim = int(encoder_meta["encoder_dim"])
|
||||||
|
|
||||||
|
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||||
|
logging.info(f"T: {T}")
|
||||||
|
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
||||||
|
logging.info(f"memory_size: {memory_size}")
|
||||||
|
logging.info(f"cnn_module_kernel: {cnn_module_kernel}")
|
||||||
|
logging.info(f"left_context_length: {left_context_length} (after subsampling)")
|
||||||
|
logging.info(f"right_context_length: {right_context_length}")
|
||||||
|
logging.info(f"encoder_dim: {encoder_dim}")
|
||||||
|
|
||||||
|
N = batch_size
|
||||||
|
|
||||||
|
states = []
|
||||||
|
for i in range(num_encoder_layers):
|
||||||
|
s0 = torch.zeros(memory_size, N, encoder_dim)
|
||||||
|
s1 = torch.zeros(left_context_length, N, encoder_dim)
|
||||||
|
s2 = torch.zeros(left_context_length, N, encoder_dim)
|
||||||
|
s3 = torch.zeros(N, encoder_dim, cnn_module_kernel - 1)
|
||||||
|
states.extend([s0, s1, s2, s3])
|
||||||
|
|
||||||
|
self.states = states
|
||||||
|
|
||||||
|
self.segment = T
|
||||||
|
self.offset = decode_chunk_len
|
||||||
|
self.num_encoder_layers = num_encoder_layers
|
||||||
|
|
||||||
|
def init_decoder(self, decoder_model_filename: str):
|
||||||
|
self.decoder = ort.InferenceSession(
|
||||||
|
decoder_model_filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||||
|
self.context_size = int(decoder_meta["context_size"])
|
||||||
|
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||||
|
|
||||||
|
logging.info(f"context_size: {self.context_size}")
|
||||||
|
logging.info(f"vocab_size: {self.vocab_size}")
|
||||||
|
|
||||||
|
def init_joiner(self, joiner_model_filename: str):
|
||||||
|
self.joiner = ort.InferenceSession(
|
||||||
|
joiner_model_filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||||
|
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||||
|
|
||||||
|
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||||
|
|
||||||
|
def _build_encoder_input_output(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
||||||
|
encoder_input = {"x": x.numpy()}
|
||||||
|
encoder_output = ["encoder_out"]
|
||||||
|
|
||||||
|
def build_inputs_outputs(states: List[torch.Tensor], name: str):
|
||||||
|
for i in range(4):
|
||||||
|
if isinstance(states[i], torch.Tensor):
|
||||||
|
encoder_input[f"{name}_{i}"] = states[i].numpy()
|
||||||
|
else:
|
||||||
|
encoder_input[f"{name}_{i}"] = states[i]
|
||||||
|
|
||||||
|
encoder_output.append(f"new_{name}_{i}")
|
||||||
|
|
||||||
|
for i in range(self.num_encoder_layers):
|
||||||
|
base_name = f"layer{i}"
|
||||||
|
s = self.states[i * 4 : (i + 1) * 4]
|
||||||
|
build_inputs_outputs(s, base_name)
|
||||||
|
|
||||||
|
return encoder_input, encoder_output
|
||||||
|
|
||||||
|
def _update_states(self, states: List[np.ndarray]):
|
||||||
|
self.states = states
|
||||||
|
|
||||||
|
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
Returns:
|
||||||
|
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
||||||
|
T' is usually equal to ((T-7)//2+1)//2
|
||||||
|
"""
|
||||||
|
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||||
|
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||||
|
|
||||||
|
self._update_states(out[1:])
|
||||||
|
|
||||||
|
return torch.from_numpy(out[0])
|
||||||
|
|
||||||
|
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
decoder_input:
|
||||||
|
A 2-D tensor of shape (N, context_size)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
out = self.decoder.run(
|
||||||
|
[self.decoder.get_outputs()[0].name],
|
||||||
|
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return torch.from_numpy(out)
|
||||||
|
|
||||||
|
def run_joiner(
|
||||||
|
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
out = self.joiner.run(
|
||||||
|
[self.joiner.get_outputs()[0].name],
|
||||||
|
{
|
||||||
|
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||||
|
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return torch.from_numpy(out)
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert (
|
||||||
|
sample_rate == expected_sample_rate
|
||||||
|
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0].contiguous())
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||||
|
"""Create a CPU streaming feature extractor.
|
||||||
|
|
||||||
|
At present, we assume it returns a fbank feature extractor with
|
||||||
|
fixed options. In the future, we will support passing in the options
|
||||||
|
from outside.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a CPU streaming feature extractor.
|
||||||
|
"""
|
||||||
|
opts = FbankOptions()
|
||||||
|
opts.device = "cpu"
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = 16000
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
return OnlineFbank(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
|
||||||
|
model: OnnxModel,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
context_size: int,
|
||||||
|
decoder_out: Optional[torch.Tensor] = None,
|
||||||
|
hyp: Optional[List[int]] = None,
|
||||||
|
) -> List[int]:
|
||||||
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
A 3-D tensor of shape (1, T, joiner_dim)
|
||||||
|
context_size:
|
||||||
|
The context size of the decoder model.
|
||||||
|
decoder_out:
|
||||||
|
Optional. Decoder output of the previous chunk.
|
||||||
|
hyp:
|
||||||
|
Decoding results for previous chunks.
|
||||||
|
Returns:
|
||||||
|
Return the decoded results so far.
|
||||||
|
"""
|
||||||
|
|
||||||
|
blank_id = 0
|
||||||
|
|
||||||
|
if decoder_out is None:
|
||||||
|
assert hyp is None, hyp
|
||||||
|
hyp = [blank_id] * context_size
|
||||||
|
decoder_input = torch.tensor([hyp], dtype=torch.int64)
|
||||||
|
decoder_out = model.run_decoder(decoder_input)
|
||||||
|
else:
|
||||||
|
assert hyp is not None, hyp
|
||||||
|
|
||||||
|
encoder_out = encoder_out.squeeze(0)
|
||||||
|
T = encoder_out.size(0)
|
||||||
|
for t in range(T):
|
||||||
|
cur_encoder_out = encoder_out[t : t + 1]
|
||||||
|
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||||
|
y = joiner_out.argmax(dim=0).item()
|
||||||
|
if y != blank_id:
|
||||||
|
hyp.append(y)
|
||||||
|
decoder_input = hyp[-context_size:]
|
||||||
|
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
||||||
|
decoder_out = model.run_decoder(decoder_input)
|
||||||
|
|
||||||
|
return hyp, decoder_out
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
model = OnnxModel(
|
||||||
|
encoder_model_filename=args.encoder_model_filename,
|
||||||
|
decoder_model_filename=args.decoder_model_filename,
|
||||||
|
joiner_model_filename=args.joiner_model_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
sample_rate = 16000
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
online_fbank = create_streaming_feature_extractor()
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {args.sound_file}")
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=[args.sound_file],
|
||||||
|
expected_sample_rate=sample_rate,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||||
|
wave_samples = torch.cat([waves, tail_padding])
|
||||||
|
|
||||||
|
num_processed_frames = 0
|
||||||
|
segment = model.segment
|
||||||
|
offset = model.offset
|
||||||
|
|
||||||
|
context_size = model.context_size
|
||||||
|
hyp = None
|
||||||
|
decoder_out = None
|
||||||
|
|
||||||
|
chunk = int(1 * sample_rate) # 1 second
|
||||||
|
start = 0
|
||||||
|
while start < wave_samples.numel():
|
||||||
|
end = min(start + chunk, wave_samples.numel())
|
||||||
|
samples = wave_samples[start:end]
|
||||||
|
start += chunk
|
||||||
|
|
||||||
|
online_fbank.accept_waveform(
|
||||||
|
sampling_rate=sample_rate,
|
||||||
|
waveform=samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||||
|
frames = []
|
||||||
|
for i in range(segment):
|
||||||
|
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||||
|
num_processed_frames += offset
|
||||||
|
frames = torch.cat(frames, dim=0)
|
||||||
|
frames = frames.unsqueeze(0)
|
||||||
|
encoder_out = model.run_encoder(frames)
|
||||||
|
hyp, decoder_out = greedy_search(
|
||||||
|
model,
|
||||||
|
encoder_out,
|
||||||
|
context_size,
|
||||||
|
decoder_out,
|
||||||
|
hyp,
|
||||||
|
)
|
||||||
|
|
||||||
|
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
for i in hyp[context_size:]:
|
||||||
|
text += symbol_table[i]
|
||||||
|
text = text.replace("▁", " ").strip()
|
||||||
|
|
||||||
|
logging.info(args.sound_file)
|
||||||
|
logging.info(text)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -425,6 +425,7 @@ def get_params() -> AttributeDict:
|
|||||||
"joiner_dim": 512,
|
"joiner_dim": 512,
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"model_warm_step": 3000, # arg given to model, not for lrate
|
"model_warm_step": 3000, # arg given to model, not for lrate
|
||||||
|
"is_pnnx": True,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -446,6 +447,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
left_context_length=params.left_context_length,
|
left_context_length=params.left_context_length,
|
||||||
right_context_length=params.right_context_length,
|
right_context_length=params.right_context_length,
|
||||||
memory_size=params.memory_size,
|
memory_size=params.memory_size,
|
||||||
|
is_pnnx=params.is_pnnx,
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
1
egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../lstm_transducer_stateless2/export-onnx.py
|
1
egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../lstm_transducer_stateless2/onnx_check.py
|
1
egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../lstm_transducer_stateless2/onnx_pretrained.py
|
337
egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py
Executable file
337
egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py
Executable file
@ -0,0 +1,337 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
Please see
|
||||||
|
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
|
||||||
|
for more details about how to use this file.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export via torch.jit.trace()
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
|
||||||
|
cd ./lstm_transducer_stateless2/exp
|
||||||
|
pnnx encoder_jit_trace-pnnx.pt
|
||||||
|
pnnx decoder_jit_trace-pnnx.pt
|
||||||
|
pnnx joiner_jit_trace-pnnx.pt
|
||||||
|
|
||||||
|
See ./streaming-ncnn-decode.py
|
||||||
|
and
|
||||||
|
https://github.com/k2-fsa/sherpa-ncnn
|
||||||
|
for usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless2/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_jit_trace(
|
||||||
|
encoder_model: torch.nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The warmup argument is fixed to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
states = encoder_model.get_init_states()
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
|
||||||
|
traced_model.save(encoder_filename)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_jit_trace(
|
||||||
|
decoder_model: torch.nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given decoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The input decoder model
|
||||||
|
decoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
need_pad = torch.tensor([False])
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||||
|
traced_model.save(decoder_filename)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_jit_trace(
|
||||||
|
joiner_model: torch.nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given joiner model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument project_input is fixed to True. A user should not
|
||||||
|
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||||
|
will do that for the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
joiner_model:
|
||||||
|
The input joiner model
|
||||||
|
joiner_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||||
|
traced_model.save(joiner_filename)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
params.is_pnnx = True
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params, enable_giga=False)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
logging.info("Using torch.jit.trace()")
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
|
||||||
|
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
|
||||||
|
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
|
||||||
|
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
main()
|
593
egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py
Executable file
593
egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py
Executable file
@ -0,0 +1,593 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a transducer model from PyTorch to ONNX.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/export-onnx.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||||
|
use the exported ONNX models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from decoder import Decoder
|
||||||
|
from lstm import RNN
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless5/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = value
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxEncoder(nn.Module):
|
||||||
|
"""A wrapper for RNN and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, encoder: RNN, encoder_proj: nn.Linear):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
An RNN encoder.
|
||||||
|
encoder_proj:
|
||||||
|
The projection layer for encoder from the joiner.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Please see the help information of RNN.forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
states:
|
||||||
|
A tuple of 2 tensors (optional). It is for streaming inference.
|
||||||
|
states[0] is the hidden states of all layers,
|
||||||
|
with shape of (num_layers, N, d_model);
|
||||||
|
states[1] is the cell states of all layers,
|
||||||
|
with shape of (num_layers, N, rnn_hidden_size).
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||||
|
- updated states, whose shape is the same as the input states.
|
||||||
|
"""
|
||||||
|
N = x.size(0)
|
||||||
|
T = x.size(1)
|
||||||
|
x_lens = torch.tensor([T] * N, dtype=torch.int64, device=x.device)
|
||||||
|
encoder_out, _, next_states = self.encoder(x, x_lens, states)
|
||||||
|
|
||||||
|
encoder_out = self.encoder_proj(encoder_out)
|
||||||
|
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||||
|
|
||||||
|
return encoder_out, next_states
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxDecoder(nn.Module):
|
||||||
|
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder
|
||||||
|
self.decoder_proj = decoder_proj
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, context_size).
|
||||||
|
Returns
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||||
|
decoder_output = decoder_output.squeeze(1)
|
||||||
|
output = self.decoder_proj(decoder_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxJoiner(nn.Module):
|
||||||
|
"""A wrapper for the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, output_linear: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.output_linear = output_linear
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
logit = encoder_out + decoder_out
|
||||||
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
return logit
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: OnnxEncoder,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model to ONNX format.
|
||||||
|
The exported model has the following inputs:
|
||||||
|
|
||||||
|
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||||
|
- state0, a tensor of shape (num_encoder_layers, batch_size, d_model)
|
||||||
|
- state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size)
|
||||||
|
|
||||||
|
and it has 3 outputs:
|
||||||
|
|
||||||
|
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||||
|
- new_state0, a tensor of shape (num_encoder_layers, batch_size, d_model)
|
||||||
|
- new_state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
num_encoder_layers = encoder_model.encoder.num_encoder_layers
|
||||||
|
d_model = encoder_model.encoder.d_model
|
||||||
|
rnn_hidden_size = encoder_model.encoder.rnn_hidden_size
|
||||||
|
|
||||||
|
decode_chunk_len = 4
|
||||||
|
T = 9
|
||||||
|
|
||||||
|
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||||
|
states = encoder_model.encoder.get_init_states()
|
||||||
|
# state0: (num_encoder_layers, batch_size, d_model)
|
||||||
|
# state1: (num_encoder_layers, batch_size, rnn_hidden_size)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, states),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["x", "state0", "state1"],
|
||||||
|
output_names=["encoder_out", "new_state0", "new_state1"],
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N"},
|
||||||
|
"state0": {1: "N"},
|
||||||
|
"state1": {1: "N"},
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"new_state0": {1: "N"},
|
||||||
|
"new_state1": {1: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "lstm",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||||
|
"T": str(T), # 39
|
||||||
|
"num_encoder_layers": str(num_encoder_layers),
|
||||||
|
"d_model": str(d_model),
|
||||||
|
"rnn_hidden_size": str(rnn_hidden_size),
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx(
|
||||||
|
decoder_model: OnnxDecoder,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
context_size = decoder_model.decoder.context_size
|
||||||
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
|
|
||||||
|
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
y,
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"context_size": str(context_size),
|
||||||
|
"vocab_size": str(vocab_size),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"joiner_dim": str(joiner_dim),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params, enable_giga=False)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||||
|
|
||||||
|
encoder = OnnxEncoder(
|
||||||
|
encoder=model.encoder,
|
||||||
|
encoder_proj=model.joiner.encoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = OnnxDecoder(
|
||||||
|
decoder=model.decoder,
|
||||||
|
decoder_proj=model.joiner.decoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
suffix = f"iter-{params.iter}"
|
||||||
|
else:
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
suffix += f"-avg-{params.avg}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
main()
|
@ -74,29 +74,6 @@ with the following commands:
|
|||||||
git lfs install
|
git lfs install
|
||||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
# You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp
|
# You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp
|
||||||
|
|
||||||
(3) Export to ONNX format
|
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
|
||||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 20 \
|
|
||||||
--avg 10 \
|
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
It will generate the following files in the given `exp_dir`.
|
|
||||||
|
|
||||||
- encoder.onnx
|
|
||||||
- decoder.onnx
|
|
||||||
- joiner.onnx
|
|
||||||
- joiner_encoder_proj.onnx
|
|
||||||
- joiner_decoder_proj.onnx
|
|
||||||
|
|
||||||
Please see ./streaming-onnx-decode.py for usage of the generated files
|
|
||||||
|
|
||||||
Check
|
|
||||||
https://github.com/k2-fsa/sherpa-onnx
|
|
||||||
for how to use the exported models outside of icefall.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -192,35 +169,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--pnnx",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""True to save a model after applying torch.jit.trace for later
|
|
||||||
converting to PNNX. It will generate 3 files:
|
|
||||||
- encoder_jit_trace-pnnx.pt
|
|
||||||
- decoder_jit_trace-pnnx.pt
|
|
||||||
- joiner_jit_trace-pnnx.pt
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--onnx",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""If True, --jit and --pnnx are ignored and it exports the model
|
|
||||||
to onnx format. It will generate the following files:
|
|
||||||
|
|
||||||
- encoder.onnx
|
|
||||||
- decoder.onnx
|
|
||||||
- joiner.onnx
|
|
||||||
- joiner_encoder_proj.onnx
|
|
||||||
- joiner_decoder_proj.onnx
|
|
||||||
|
|
||||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -305,209 +253,6 @@ def export_joiner_model_jit_trace(
|
|||||||
logging.info(f"Saved to {joiner_filename}")
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
def export_encoder_model_onnx(
|
|
||||||
encoder_model: nn.Module,
|
|
||||||
encoder_filename: str,
|
|
||||||
opset_version: int = 11,
|
|
||||||
) -> None:
|
|
||||||
"""Export the given encoder model to ONNX format.
|
|
||||||
The exported model has 3 inputs:
|
|
||||||
|
|
||||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
|
||||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
|
||||||
- states: a tuple containing:
|
|
||||||
- h0: a tensor of shape (num_layers, N, proj_size)
|
|
||||||
- c0: a tensor of shape (num_layers, N, hidden_size)
|
|
||||||
|
|
||||||
and it has 3 outputs:
|
|
||||||
|
|
||||||
- encoder_out, a tensor of shape (N, T, C)
|
|
||||||
- encoder_out_lens, a tensor of shape (N,)
|
|
||||||
- states: a tuple containing:
|
|
||||||
- next_h0: a tensor of shape (num_layers, N, proj_size)
|
|
||||||
- next_c0: a tensor of shape (num_layers, N, hidden_size)
|
|
||||||
|
|
||||||
Note: The warmup argument is fixed to 1.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
encoder_model:
|
|
||||||
The input encoder model
|
|
||||||
encoder_filename:
|
|
||||||
The filename to save the exported ONNX model.
|
|
||||||
opset_version:
|
|
||||||
The opset version to use.
|
|
||||||
"""
|
|
||||||
N = 1
|
|
||||||
x = torch.zeros(N, 9, 80, dtype=torch.float32)
|
|
||||||
x_lens = torch.tensor([9], dtype=torch.int64)
|
|
||||||
h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model)
|
|
||||||
c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size)
|
|
||||||
|
|
||||||
warmup = 1.0
|
|
||||||
torch.onnx.export(
|
|
||||||
encoder_model, # use torch.jit.trace() internally
|
|
||||||
(x, x_lens, (h, c), warmup),
|
|
||||||
encoder_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["x", "x_lens", "h", "c", "warmup"],
|
|
||||||
output_names=["encoder_out", "encoder_out_lens", "next_h", "next_c"],
|
|
||||||
dynamic_axes={
|
|
||||||
"x": {0: "N", 1: "T"},
|
|
||||||
"x_lens": {0: "N"},
|
|
||||||
"h": {1: "N"},
|
|
||||||
"c": {1: "N"},
|
|
||||||
"encoder_out": {0: "N", 1: "T"},
|
|
||||||
"encoder_out_lens": {0: "N"},
|
|
||||||
"next_h": {1: "N"},
|
|
||||||
"next_c": {1: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {encoder_filename}")
|
|
||||||
|
|
||||||
|
|
||||||
def export_decoder_model_onnx(
|
|
||||||
decoder_model: nn.Module,
|
|
||||||
decoder_filename: str,
|
|
||||||
opset_version: int = 11,
|
|
||||||
) -> None:
|
|
||||||
"""Export the decoder model to ONNX format.
|
|
||||||
|
|
||||||
The exported model has one input:
|
|
||||||
|
|
||||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
|
||||||
|
|
||||||
and has one output:
|
|
||||||
|
|
||||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
|
||||||
|
|
||||||
Note: The argument need_pad is fixed to False.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
decoder_model:
|
|
||||||
The decoder model to be exported.
|
|
||||||
decoder_filename:
|
|
||||||
Filename to save the exported ONNX model.
|
|
||||||
opset_version:
|
|
||||||
The opset version to use.
|
|
||||||
"""
|
|
||||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
|
||||||
need_pad = False # Always False, so we can use torch.jit.trace() here
|
|
||||||
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
|
||||||
# in this case
|
|
||||||
torch.onnx.export(
|
|
||||||
decoder_model,
|
|
||||||
(y, need_pad),
|
|
||||||
decoder_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["y", "need_pad"],
|
|
||||||
output_names=["decoder_out"],
|
|
||||||
dynamic_axes={
|
|
||||||
"y": {0: "N"},
|
|
||||||
"decoder_out": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {decoder_filename}")
|
|
||||||
|
|
||||||
|
|
||||||
def export_joiner_model_onnx(
|
|
||||||
joiner_model: nn.Module,
|
|
||||||
joiner_filename: str,
|
|
||||||
opset_version: int = 11,
|
|
||||||
) -> None:
|
|
||||||
"""Export the joiner model to ONNX format.
|
|
||||||
The exported joiner model has two inputs:
|
|
||||||
|
|
||||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
|
|
||||||
and produces one output:
|
|
||||||
|
|
||||||
- logit: a tensor of shape (N, vocab_size)
|
|
||||||
|
|
||||||
The exported encoder_proj model has one input:
|
|
||||||
|
|
||||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
|
||||||
|
|
||||||
and produces one output:
|
|
||||||
|
|
||||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
|
|
||||||
The exported decoder_proj model has one input:
|
|
||||||
|
|
||||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
|
||||||
|
|
||||||
and produces one output:
|
|
||||||
|
|
||||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
"""
|
|
||||||
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
|
|
||||||
|
|
||||||
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
|
|
||||||
|
|
||||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
|
||||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
|
||||||
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
|
||||||
|
|
||||||
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
|
||||||
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
|
||||||
|
|
||||||
project_input = False
|
|
||||||
# Note: It uses torch.jit.trace() internally
|
|
||||||
torch.onnx.export(
|
|
||||||
joiner_model,
|
|
||||||
(projected_encoder_out, projected_decoder_out, project_input),
|
|
||||||
joiner_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=[
|
|
||||||
"projected_encoder_out",
|
|
||||||
"projected_decoder_out",
|
|
||||||
"project_input",
|
|
||||||
],
|
|
||||||
output_names=["logit"],
|
|
||||||
dynamic_axes={
|
|
||||||
"projected_encoder_out": {0: "N"},
|
|
||||||
"projected_decoder_out": {0: "N"},
|
|
||||||
"logit": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {joiner_filename}")
|
|
||||||
|
|
||||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
|
||||||
torch.onnx.export(
|
|
||||||
joiner_model.encoder_proj,
|
|
||||||
encoder_out,
|
|
||||||
encoder_proj_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["encoder_out"],
|
|
||||||
output_names=["projected_encoder_out"],
|
|
||||||
dynamic_axes={
|
|
||||||
"encoder_out": {0: "N"},
|
|
||||||
"projected_encoder_out": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {encoder_proj_filename}")
|
|
||||||
|
|
||||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
|
||||||
torch.onnx.export(
|
|
||||||
joiner_model.decoder_proj,
|
|
||||||
decoder_out,
|
|
||||||
decoder_proj_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["decoder_out"],
|
|
||||||
output_names=["projected_decoder_out"],
|
|
||||||
dynamic_axes={
|
|
||||||
"decoder_out": {0: "N"},
|
|
||||||
"projected_decoder_out": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {decoder_proj_filename}")
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
@ -531,10 +276,6 @@ def main():
|
|||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
if params.pnnx:
|
|
||||||
params.is_pnnx = params.pnnx
|
|
||||||
logging.info("For PNNX")
|
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_transducer_model(params, enable_giga=False)
|
model = get_transducer_model(params, enable_giga=False)
|
||||||
|
|
||||||
@ -629,44 +370,7 @@ def main():
|
|||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.onnx:
|
if params.jit_trace is True:
|
||||||
logging.info("Export model to ONNX format")
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
|
||||||
|
|
||||||
opset_version = 11
|
|
||||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
|
||||||
export_encoder_model_onnx(
|
|
||||||
model.encoder,
|
|
||||||
encoder_filename,
|
|
||||||
opset_version=opset_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
|
||||||
export_decoder_model_onnx(
|
|
||||||
model.decoder,
|
|
||||||
decoder_filename,
|
|
||||||
opset_version=opset_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
|
||||||
export_joiner_model_onnx(
|
|
||||||
model.joiner,
|
|
||||||
joiner_filename,
|
|
||||||
opset_version=opset_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif params.pnnx:
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
|
||||||
logging.info("Using torch.jit.trace()")
|
|
||||||
encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
|
|
||||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
|
||||||
|
|
||||||
decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt"
|
|
||||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
|
||||||
|
|
||||||
joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt"
|
|
||||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
|
||||||
elif params.jit_trace is True:
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
logging.info("Using torch.jit.trace()")
|
logging.info("Using torch.jit.trace()")
|
||||||
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
./lstm_transducer_stateless2/ncnn-decode.py \
|
./lstm_transducer_stateless2/ncnn-decode.py \
|
||||||
--bpe-model-filename ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
|
--encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
|
||||||
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
|
--encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
|
||||||
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
|
--decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
|
||||||
@ -27,15 +27,19 @@ Usage:
|
|||||||
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
|
--joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
|
||||||
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
|
--joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
|
||||||
./test_wavs/1089-134686-0001.wav
|
./test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
Please see
|
||||||
|
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
|
||||||
|
for details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import ncnn
|
import ncnn
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
@ -44,9 +48,9 @@ def get_args():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model-filename",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to bpe.model",
|
help="Path to tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -240,9 +244,6 @@ def main():
|
|||||||
|
|
||||||
model = Model(args)
|
model = Model(args)
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
|
||||||
sp.load(args.bpe_model_filename)
|
|
||||||
|
|
||||||
sound_file = args.sound_filename
|
sound_file = args.sound_filename
|
||||||
|
|
||||||
sample_rate = 16000
|
sample_rate = 16000
|
||||||
@ -280,8 +281,16 @@ def main():
|
|||||||
|
|
||||||
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states)
|
encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states)
|
||||||
hyp = greedy_search(model, encoder_out)
|
hyp = greedy_search(model, encoder_out)
|
||||||
|
|
||||||
|
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
for i in hyp:
|
||||||
|
text += symbol_table[i]
|
||||||
|
text = text.replace("▁", " ").strip()
|
||||||
|
|
||||||
logging.info(sound_file)
|
logging.info(sound_file)
|
||||||
logging.info(sp.decode(hyp))
|
logging.info(text)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
261
egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py
Executable file
261
egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py
Executable file
@ -0,0 +1,261 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script checks that exported ONNX models produce the same output
|
||||||
|
with the given torchscript model for the same input.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model via torch.jit.trace()
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/export.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp/ \
|
||||||
|
--jit-trace 1
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp
|
||||||
|
|
||||||
|
- encoder_jit_trace.pt
|
||||||
|
- decoder_jit_trace.pt
|
||||||
|
- joiner_jit_trace.pt
|
||||||
|
|
||||||
|
3. Export the model to ONNX
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/export-onnx.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp
|
||||||
|
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
4. Run this file
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/onnx_check.py \
|
||||||
|
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
|
||||||
|
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
|
||||||
|
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
|
||||||
|
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit-encoder-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the torchscript encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit-decoder-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the torchscript decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit-joiner-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the torchscript joiner model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx-encoder-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the ONNX encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx-decoder-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the ONNX decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx-joiner-filename",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="Path to the ONNX joiner model",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder(
|
||||||
|
torch_encoder_model: torch.jit.ScriptModule,
|
||||||
|
torch_encoder_proj_model: torch.jit.ScriptModule,
|
||||||
|
onnx_model: OnnxModel,
|
||||||
|
):
|
||||||
|
N = torch.randint(1, 100, size=(1,)).item()
|
||||||
|
T = onnx_model.segment
|
||||||
|
C = 80
|
||||||
|
x_lens = torch.tensor([T] * N)
|
||||||
|
torch_states = torch_encoder_model.get_init_states(N)
|
||||||
|
|
||||||
|
onnx_model.init_encoder_states(N)
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
logging.info(f"test_encoder: iter {i}")
|
||||||
|
x = torch.rand(N, T, C)
|
||||||
|
torch_encoder_out, _, torch_states = torch_encoder_model(
|
||||||
|
x, x_lens, torch_states
|
||||||
|
)
|
||||||
|
torch_encoder_out = torch_encoder_proj_model(torch_encoder_out)
|
||||||
|
|
||||||
|
onnx_encoder_out = onnx_model.run_encoder(x)
|
||||||
|
|
||||||
|
assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), (
|
||||||
|
(torch_encoder_out - onnx_encoder_out).abs().max()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decoder(
|
||||||
|
torch_decoder_model: torch.jit.ScriptModule,
|
||||||
|
torch_decoder_proj_model: torch.jit.ScriptModule,
|
||||||
|
onnx_model: OnnxModel,
|
||||||
|
):
|
||||||
|
context_size = onnx_model.context_size
|
||||||
|
vocab_size = onnx_model.vocab_size
|
||||||
|
for i in range(10):
|
||||||
|
N = torch.randint(1, 100, size=(1,)).item()
|
||||||
|
logging.info(f"test_decoder: iter {i}, N={N}")
|
||||||
|
x = torch.randint(
|
||||||
|
low=1,
|
||||||
|
high=vocab_size,
|
||||||
|
size=(N, context_size),
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False]))
|
||||||
|
torch_decoder_out = torch_decoder_proj_model(torch_decoder_out)
|
||||||
|
torch_decoder_out = torch_decoder_out.squeeze(1)
|
||||||
|
|
||||||
|
onnx_decoder_out = onnx_model.run_decoder(x)
|
||||||
|
assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
|
||||||
|
(torch_decoder_out - onnx_decoder_out).abs().max()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_joiner(
|
||||||
|
torch_joiner_model: torch.jit.ScriptModule,
|
||||||
|
onnx_model: OnnxModel,
|
||||||
|
):
|
||||||
|
encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
for i in range(10):
|
||||||
|
N = torch.randint(1, 100, size=(1,)).item()
|
||||||
|
logging.info(f"test_joiner: iter {i}, N={N}")
|
||||||
|
encoder_out = torch.rand(N, encoder_dim)
|
||||||
|
decoder_out = torch.rand(N, decoder_dim)
|
||||||
|
|
||||||
|
projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out)
|
||||||
|
projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out)
|
||||||
|
|
||||||
|
torch_joiner_out = torch_joiner_model(encoder_out, decoder_out)
|
||||||
|
onnx_joiner_out = onnx_model.run_joiner(
|
||||||
|
projected_encoder_out, projected_decoder_out
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
|
||||||
|
(torch_joiner_out - onnx_joiner_out).abs().max()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
torch_encoder_model = torch.jit.load(args.jit_encoder_filename)
|
||||||
|
torch_decoder_model = torch.jit.load(args.jit_decoder_filename)
|
||||||
|
torch_joiner_model = torch.jit.load(args.jit_joiner_filename)
|
||||||
|
|
||||||
|
onnx_model = OnnxModel(
|
||||||
|
encoder_model_filename=args.onnx_encoder_filename,
|
||||||
|
decoder_model_filename=args.onnx_decoder_filename,
|
||||||
|
joiner_model_filename=args.onnx_joiner_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Test encoder")
|
||||||
|
# When exporting the model to onnx, we have already put the encoder_proj
|
||||||
|
# inside the encoder.
|
||||||
|
test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model)
|
||||||
|
|
||||||
|
logging.info("Test decoder")
|
||||||
|
# When exporting the model to onnx, we have already put the decoder_proj
|
||||||
|
# inside the decoder.
|
||||||
|
test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model)
|
||||||
|
|
||||||
|
logging.info("Test joiner")
|
||||||
|
test_joiner(torch_joiner_model, onnx_model)
|
||||||
|
|
||||||
|
logging.info("Finished checking ONNX models")
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
# See https://github.com/pytorch/pytorch/issues/38342
|
||||||
|
# and https://github.com/pytorch/pytorch/issues/33354
|
||||||
|
#
|
||||||
|
# If we don't do this, the delay increases whenever there is
|
||||||
|
# a new request that changes the actual batch size.
|
||||||
|
# If you use `py-spy dump --pid <server-pid> --native`, you will
|
||||||
|
# see a lot of time is spent in re-compiling the torch script model.
|
||||||
|
torch._C._jit_set_profiling_executor(False)
|
||||||
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
torch._C._set_graph_executor_optimize(False)
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20230207)
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
428
egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py
Executable file
428
egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py
Executable file
@ -0,0 +1,428 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script loads ONNX models exported by ./export-onnx.py
|
||||||
|
and uses them to decode waves.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
cd exp
|
||||||
|
ln -s exp/pretrained-iter-468000-avg-16.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/export-onnx.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
3. Run this file with the exported ONNX models
|
||||||
|
|
||||||
|
./lstm_transducer_stateless2/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1221-135766-0001.wav
|
||||||
|
|
||||||
|
Note: Even though this script only supports decoding a single file,
|
||||||
|
the exported ONNX models do support batch processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the decoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the joiner onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
help="""Path to tokens.txt.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_file",
|
||||||
|
type=str,
|
||||||
|
help="The input sound file to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder_model_filename: str,
|
||||||
|
decoder_model_filename: str,
|
||||||
|
joiner_model_filename: str,
|
||||||
|
):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
self.session_opts = session_opts
|
||||||
|
|
||||||
|
self.init_encoder(encoder_model_filename)
|
||||||
|
self.init_decoder(decoder_model_filename)
|
||||||
|
self.init_joiner(joiner_model_filename)
|
||||||
|
|
||||||
|
def init_encoder(self, encoder_model_filename: str):
|
||||||
|
self.encoder = ort.InferenceSession(
|
||||||
|
encoder_model_filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
)
|
||||||
|
self.init_encoder_states()
|
||||||
|
|
||||||
|
def init_encoder_states(self, batch_size: int = 1):
|
||||||
|
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||||
|
|
||||||
|
model_type = encoder_meta["model_type"]
|
||||||
|
assert model_type == "lstm", model_type
|
||||||
|
|
||||||
|
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||||
|
T = int(encoder_meta["T"])
|
||||||
|
|
||||||
|
num_encoder_layers = int(encoder_meta["num_encoder_layers"])
|
||||||
|
d_model = int(encoder_meta["d_model"])
|
||||||
|
rnn_hidden_size = int(encoder_meta["rnn_hidden_size"])
|
||||||
|
|
||||||
|
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||||
|
logging.info(f"T: {T}")
|
||||||
|
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
||||||
|
logging.info(f"d_model: {d_model}")
|
||||||
|
logging.info(f"rnn_hidden_size: {rnn_hidden_size}")
|
||||||
|
|
||||||
|
N = batch_size
|
||||||
|
|
||||||
|
s0 = torch.zeros(num_encoder_layers, N, d_model)
|
||||||
|
s1 = torch.zeros(num_encoder_layers, N, rnn_hidden_size)
|
||||||
|
states = [s0.numpy(), s1.numpy()]
|
||||||
|
|
||||||
|
self.states = states
|
||||||
|
|
||||||
|
self.segment = T
|
||||||
|
self.offset = decode_chunk_len
|
||||||
|
|
||||||
|
def init_decoder(self, decoder_model_filename: str):
|
||||||
|
self.decoder = ort.InferenceSession(
|
||||||
|
decoder_model_filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||||
|
self.context_size = int(decoder_meta["context_size"])
|
||||||
|
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||||
|
|
||||||
|
logging.info(f"context_size: {self.context_size}")
|
||||||
|
logging.info(f"vocab_size: {self.vocab_size}")
|
||||||
|
|
||||||
|
def init_joiner(self, joiner_model_filename: str):
|
||||||
|
self.joiner = ort.InferenceSession(
|
||||||
|
joiner_model_filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||||
|
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||||
|
|
||||||
|
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||||
|
|
||||||
|
def _build_encoder_input_output(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> Tuple[Dict[str, np.ndarray], List[str]]:
|
||||||
|
encoder_input = {
|
||||||
|
"x": x.numpy(),
|
||||||
|
"state0": self.states[0],
|
||||||
|
"state1": self.states[1],
|
||||||
|
}
|
||||||
|
encoder_output = ["encoder_out", "new_state0", "new_state1"]
|
||||||
|
|
||||||
|
return encoder_input, encoder_output
|
||||||
|
|
||||||
|
def _update_states(self, states: List[np.ndarray]):
|
||||||
|
self.states = states
|
||||||
|
|
||||||
|
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
Returns:
|
||||||
|
Return a 3-D tensor of shape (N, T', joiner_dim) where
|
||||||
|
T' is usually equal to ((T-3)//2-1)//2
|
||||||
|
"""
|
||||||
|
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||||
|
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||||
|
|
||||||
|
self._update_states(out[1:])
|
||||||
|
|
||||||
|
return torch.from_numpy(out[0])
|
||||||
|
|
||||||
|
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
decoder_input:
|
||||||
|
A 2-D tensor of shape (N, context_size)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
out = self.decoder.run(
|
||||||
|
[self.decoder.get_outputs()[0].name],
|
||||||
|
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return torch.from_numpy(out)
|
||||||
|
|
||||||
|
def run_joiner(
|
||||||
|
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
out = self.joiner.run(
|
||||||
|
[self.joiner.get_outputs()[0].name],
|
||||||
|
{
|
||||||
|
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||||
|
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return torch.from_numpy(out)
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert (
|
||||||
|
sample_rate == expected_sample_rate
|
||||||
|
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0].contiguous())
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def create_streaming_feature_extractor() -> OnlineFeature:
|
||||||
|
"""Create a CPU streaming feature extractor.
|
||||||
|
|
||||||
|
At present, we assume it returns a fbank feature extractor with
|
||||||
|
fixed options. In the future, we will support passing in the options
|
||||||
|
from outside.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a CPU streaming feature extractor.
|
||||||
|
"""
|
||||||
|
opts = FbankOptions()
|
||||||
|
opts.device = "cpu"
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = 16000
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
return OnlineFbank(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
|
||||||
|
model: OnnxModel,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
context_size: int,
|
||||||
|
decoder_out: Optional[torch.Tensor] = None,
|
||||||
|
hyp: Optional[List[int]] = None,
|
||||||
|
) -> List[int]:
|
||||||
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
A 3-D tensor of shape (1, T, joiner_dim)
|
||||||
|
context_size:
|
||||||
|
The context size of the decoder model.
|
||||||
|
decoder_out:
|
||||||
|
Optional. Decoder output of the previous chunk.
|
||||||
|
hyp:
|
||||||
|
Decoding results for previous chunks.
|
||||||
|
Returns:
|
||||||
|
Return the decoded results so far.
|
||||||
|
"""
|
||||||
|
|
||||||
|
blank_id = 0
|
||||||
|
|
||||||
|
if decoder_out is None:
|
||||||
|
assert hyp is None, hyp
|
||||||
|
hyp = [blank_id] * context_size
|
||||||
|
decoder_input = torch.tensor([hyp], dtype=torch.int64)
|
||||||
|
decoder_out = model.run_decoder(decoder_input)
|
||||||
|
else:
|
||||||
|
assert hyp is not None, hyp
|
||||||
|
|
||||||
|
encoder_out = encoder_out.squeeze(0)
|
||||||
|
T = encoder_out.size(0)
|
||||||
|
for t in range(T):
|
||||||
|
cur_encoder_out = encoder_out[t : t + 1]
|
||||||
|
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||||
|
y = joiner_out.argmax(dim=0).item()
|
||||||
|
if y != blank_id:
|
||||||
|
hyp.append(y)
|
||||||
|
decoder_input = hyp[-context_size:]
|
||||||
|
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
|
||||||
|
decoder_out = model.run_decoder(decoder_input)
|
||||||
|
|
||||||
|
return hyp, decoder_out
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
model = OnnxModel(
|
||||||
|
encoder_model_filename=args.encoder_model_filename,
|
||||||
|
decoder_model_filename=args.decoder_model_filename,
|
||||||
|
joiner_model_filename=args.joiner_model_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
sample_rate = 16000
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
online_fbank = create_streaming_feature_extractor()
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {args.sound_file}")
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=[args.sound_file],
|
||||||
|
expected_sample_rate=sample_rate,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
|
||||||
|
wave_samples = torch.cat([waves, tail_padding])
|
||||||
|
|
||||||
|
num_processed_frames = 0
|
||||||
|
segment = model.segment
|
||||||
|
offset = model.offset
|
||||||
|
|
||||||
|
context_size = model.context_size
|
||||||
|
hyp = None
|
||||||
|
decoder_out = None
|
||||||
|
|
||||||
|
chunk = int(1 * sample_rate) # 1 second
|
||||||
|
start = 0
|
||||||
|
while start < wave_samples.numel():
|
||||||
|
end = min(start + chunk, wave_samples.numel())
|
||||||
|
samples = wave_samples[start:end]
|
||||||
|
start += chunk
|
||||||
|
|
||||||
|
online_fbank.accept_waveform(
|
||||||
|
sampling_rate=sample_rate,
|
||||||
|
waveform=samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||||
|
frames = []
|
||||||
|
for i in range(segment):
|
||||||
|
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||||
|
num_processed_frames += offset
|
||||||
|
frames = torch.cat(frames, dim=0)
|
||||||
|
frames = frames.unsqueeze(0)
|
||||||
|
encoder_out = model.run_encoder(frames)
|
||||||
|
hyp, decoder_out = greedy_search(
|
||||||
|
model,
|
||||||
|
encoder_out,
|
||||||
|
context_size,
|
||||||
|
decoder_out,
|
||||||
|
hyp,
|
||||||
|
)
|
||||||
|
|
||||||
|
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
for i in hyp[context_size:]:
|
||||||
|
text += symbol_table[i]
|
||||||
|
text = text.replace("▁", " ").strip()
|
||||||
|
|
||||||
|
logging.info(args.sound_file)
|
||||||
|
logging.info(text)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -16,13 +16,18 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Please see
|
||||||
|
https://k2-fsa.github.io/icefall/model-export/export-ncnn.html
|
||||||
|
for usage
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import k2
|
||||||
import ncnn
|
import ncnn
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
@ -32,9 +37,9 @@ def get_args():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model-filename",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to bpe.model",
|
help="Path to tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -251,9 +256,6 @@ def main():
|
|||||||
|
|
||||||
model = Model(args)
|
model = Model(args)
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
|
||||||
sp.load(args.bpe_model_filename)
|
|
||||||
|
|
||||||
sound_file = args.sound_filename
|
sound_file = args.sound_filename
|
||||||
|
|
||||||
sample_rate = 16000
|
sample_rate = 16000
|
||||||
@ -329,10 +331,16 @@ def main():
|
|||||||
model, encoder_out.squeeze(0), decoder_out, hyp
|
model, encoder_out.squeeze(0), decoder_out, hyp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
||||||
|
|
||||||
context_size = 2
|
context_size = 2
|
||||||
|
text = ""
|
||||||
|
for i in hyp[context_size:]:
|
||||||
|
text += symbol_table[i]
|
||||||
|
text = text.replace("▁", " ").strip()
|
||||||
|
|
||||||
logging.info(sound_file)
|
logging.info(sound_file)
|
||||||
logging.info(sp.decode(hyp[context_size:]))
|
logging.info(text)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
1
egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../lstm_transducer_stateless2/export-onnx.py
|
1
egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py
Symbolic link
1
egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../lstm_transducer_stateless2/onnx_check.py
|
@ -0,0 +1 @@
|
|||||||
|
../lstm_transducer_stateless2/onnx_pretrained.py
|
@ -125,13 +125,13 @@ def test_encoder(
|
|||||||
onnx_model: OnnxModel,
|
onnx_model: OnnxModel,
|
||||||
):
|
):
|
||||||
C = 80
|
C = 80
|
||||||
for i in range(10):
|
for i in range(3):
|
||||||
N = torch.randint(low=1, high=20, size=(1,)).item()
|
N = torch.randint(low=1, high=20, size=(1,)).item()
|
||||||
T = torch.randint(low=50, high=100, size=(1,)).item()
|
T = torch.randint(low=30, high=50, size=(1,)).item()
|
||||||
logging.info(f"test_encoder: iter {i}, N={N}, T={T}")
|
logging.info(f"test_encoder: iter {i}, N={N}, T={T}")
|
||||||
|
|
||||||
x = torch.rand(N, T, C)
|
x = torch.rand(N, T, C)
|
||||||
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
|
x_lens = torch.randint(low=30, high=T + 1, size=(N,))
|
||||||
x_lens[0] = T
|
x_lens[0] = T
|
||||||
|
|
||||||
torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)
|
torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)
|
||||||
|
560
egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py
Executable file
560
egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py
Executable file
@ -0,0 +1,560 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a transducer model from PyTorch to ONNX.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
|
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained-epoch-30-avg-9.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/export-onnx.py \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024"
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
See ./onnx_pretrained.py and ./onnx_check.py for how to
|
||||||
|
use the exported ONNX models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from decoder import Decoder
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
from zipformer import Zipformer
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless5/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = value
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxEncoder(nn.Module):
|
||||||
|
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
A Zipformer encoder.
|
||||||
|
encoder_proj:
|
||||||
|
The projection layer for encoder from the joiner.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Please see the help information of Zipformer.forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||||
|
"""
|
||||||
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
|
||||||
|
|
||||||
|
encoder_out = self.encoder_proj(encoder_out)
|
||||||
|
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||||
|
|
||||||
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxDecoder(nn.Module):
|
||||||
|
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder
|
||||||
|
self.decoder_proj = decoder_proj
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, context_size).
|
||||||
|
Returns
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||||
|
decoder_output = decoder_output.squeeze(1)
|
||||||
|
output = self.decoder_proj(decoder_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxJoiner(nn.Module):
|
||||||
|
"""A wrapper for the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, output_linear: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.output_linear = output_linear
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
logit = encoder_out + decoder_out
|
||||||
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
return logit
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: OnnxEncoder,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model to ONNX format.
|
||||||
|
The exported model has two inputs:
|
||||||
|
|
||||||
|
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||||
|
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||||
|
|
||||||
|
and it has two outputs:
|
||||||
|
|
||||||
|
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, a tensor of shape (N,)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, x_lens),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["x", "x_lens"],
|
||||||
|
output_names=["encoder_out", "encoder_out_lens"],
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N", 1: "T"},
|
||||||
|
"x_lens": {0: "N"},
|
||||||
|
"encoder_out": {0: "N", 1: "T"},
|
||||||
|
"encoder_out_lens": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx(
|
||||||
|
decoder_model: OnnxDecoder,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
context_size = decoder_model.decoder.context_size
|
||||||
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
|
|
||||||
|
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
y,
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"context_size": str(context_size),
|
||||||
|
"vocab_size": str(vocab_size),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"joiner_dim": str(joiner_dim),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
|
encoder = OnnxEncoder(
|
||||||
|
encoder=model.encoder,
|
||||||
|
encoder_proj=model.joiner.encoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = OnnxDecoder(
|
||||||
|
decoder=model.decoder,
|
||||||
|
decoder_proj=model.joiner.decoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
suffix = f"iter-{params.iter}"
|
||||||
|
else:
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
suffix += f"-avg-{params.avg}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
main()
|
@ -41,31 +41,7 @@ Check
|
|||||||
https://github.com/k2-fsa/sherpa
|
https://github.com/k2-fsa/sherpa
|
||||||
for how to use the exported models outside of icefall.
|
for how to use the exported models outside of icefall.
|
||||||
|
|
||||||
(2) Export to ONNX format
|
(2) Export `model.state_dict()`
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 20 \
|
|
||||||
--avg 10 \
|
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
It will generate the following files in the given `exp_dir`.
|
|
||||||
Check `onnx_check.py` for how to use them.
|
|
||||||
|
|
||||||
- encoder.onnx
|
|
||||||
- decoder.onnx
|
|
||||||
- joiner.onnx
|
|
||||||
- joiner_encoder_proj.onnx
|
|
||||||
- joiner_decoder_proj.onnx
|
|
||||||
|
|
||||||
Please see ./onnx_pretrained.py for usage of the generated files
|
|
||||||
|
|
||||||
Check
|
|
||||||
https://github.com/k2-fsa/sherpa-onnx
|
|
||||||
for how to use the exported models outside of icefall.
|
|
||||||
|
|
||||||
(3) Export `model.state_dict()`
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
./pruned_transducer_stateless7/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
@ -196,23 +172,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--onnx",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""If True, --jit is ignored and it exports the model
|
|
||||||
to onnx format. It will generate the following files:
|
|
||||||
|
|
||||||
- encoder.onnx
|
|
||||||
- decoder.onnx
|
|
||||||
- joiner.onnx
|
|
||||||
- joiner_encoder_proj.onnx
|
|
||||||
- joiner_decoder_proj.onnx
|
|
||||||
|
|
||||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -225,204 +184,6 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def export_encoder_model_onnx(
|
|
||||||
encoder_model: nn.Module,
|
|
||||||
encoder_filename: str,
|
|
||||||
opset_version: int = 11,
|
|
||||||
) -> None:
|
|
||||||
"""Export the given encoder model to ONNX format.
|
|
||||||
The exported model has two inputs:
|
|
||||||
|
|
||||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
|
||||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
|
||||||
|
|
||||||
and it has two outputs:
|
|
||||||
|
|
||||||
- encoder_out, a tensor of shape (N, T, C)
|
|
||||||
- encoder_out_lens, a tensor of shape (N,)
|
|
||||||
|
|
||||||
Note: The warmup argument is fixed to 1.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
encoder_model:
|
|
||||||
The input encoder model
|
|
||||||
encoder_filename:
|
|
||||||
The filename to save the exported ONNX model.
|
|
||||||
opset_version:
|
|
||||||
The opset version to use.
|
|
||||||
"""
|
|
||||||
x = torch.zeros(1, 101, 80, dtype=torch.float32)
|
|
||||||
x_lens = torch.tensor([101], dtype=torch.int64)
|
|
||||||
|
|
||||||
# encoder_model = torch.jit.script(encoder_model)
|
|
||||||
# It throws the following error for the above statement
|
|
||||||
#
|
|
||||||
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
|
||||||
# 11 is not supported. Please feel free to request support or
|
|
||||||
# submit a pull request on PyTorch GitHub.
|
|
||||||
#
|
|
||||||
# I cannot find which statement causes the above error.
|
|
||||||
# torch.onnx.export() will use torch.jit.trace() internally, which
|
|
||||||
# works well for the current reworked model
|
|
||||||
torch.onnx.export(
|
|
||||||
encoder_model,
|
|
||||||
(x, x_lens),
|
|
||||||
encoder_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["x", "x_lens"],
|
|
||||||
output_names=["encoder_out", "encoder_out_lens"],
|
|
||||||
dynamic_axes={
|
|
||||||
"x": {0: "N", 1: "T"},
|
|
||||||
"x_lens": {0: "N"},
|
|
||||||
"encoder_out": {0: "N", 1: "T"},
|
|
||||||
"encoder_out_lens": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {encoder_filename}")
|
|
||||||
|
|
||||||
|
|
||||||
def export_decoder_model_onnx(
|
|
||||||
decoder_model: nn.Module,
|
|
||||||
decoder_filename: str,
|
|
||||||
opset_version: int = 11,
|
|
||||||
) -> None:
|
|
||||||
"""Export the decoder model to ONNX format.
|
|
||||||
|
|
||||||
The exported model has one input:
|
|
||||||
|
|
||||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
|
||||||
|
|
||||||
and has one output:
|
|
||||||
|
|
||||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
|
||||||
|
|
||||||
Note: The argument need_pad is fixed to False.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
decoder_model:
|
|
||||||
The decoder model to be exported.
|
|
||||||
decoder_filename:
|
|
||||||
Filename to save the exported ONNX model.
|
|
||||||
opset_version:
|
|
||||||
The opset version to use.
|
|
||||||
"""
|
|
||||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
|
||||||
need_pad = False # Always False, so we can use torch.jit.trace() here
|
|
||||||
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
|
||||||
# in this case
|
|
||||||
torch.onnx.export(
|
|
||||||
decoder_model,
|
|
||||||
(y, need_pad),
|
|
||||||
decoder_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["y", "need_pad"],
|
|
||||||
output_names=["decoder_out"],
|
|
||||||
dynamic_axes={
|
|
||||||
"y": {0: "N"},
|
|
||||||
"decoder_out": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {decoder_filename}")
|
|
||||||
|
|
||||||
|
|
||||||
def export_joiner_model_onnx(
|
|
||||||
joiner_model: nn.Module,
|
|
||||||
joiner_filename: str,
|
|
||||||
opset_version: int = 11,
|
|
||||||
) -> None:
|
|
||||||
"""Export the joiner model to ONNX format.
|
|
||||||
The exported joiner model has two inputs:
|
|
||||||
|
|
||||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
|
|
||||||
and produces one output:
|
|
||||||
|
|
||||||
- logit: a tensor of shape (N, vocab_size)
|
|
||||||
|
|
||||||
The exported encoder_proj model has one input:
|
|
||||||
|
|
||||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
|
||||||
|
|
||||||
and produces one output:
|
|
||||||
|
|
||||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
|
|
||||||
The exported decoder_proj model has one input:
|
|
||||||
|
|
||||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
|
||||||
|
|
||||||
and produces one output:
|
|
||||||
|
|
||||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
"""
|
|
||||||
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
|
|
||||||
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
|
|
||||||
|
|
||||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
|
||||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
|
||||||
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
|
||||||
|
|
||||||
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
|
||||||
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
|
||||||
|
|
||||||
project_input = False
|
|
||||||
# Note: It uses torch.jit.trace() internally
|
|
||||||
torch.onnx.export(
|
|
||||||
joiner_model,
|
|
||||||
(projected_encoder_out, projected_decoder_out, project_input),
|
|
||||||
joiner_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=[
|
|
||||||
"encoder_out",
|
|
||||||
"decoder_out",
|
|
||||||
"project_input",
|
|
||||||
],
|
|
||||||
output_names=["logit"],
|
|
||||||
dynamic_axes={
|
|
||||||
"encoder_out": {0: "N"},
|
|
||||||
"decoder_out": {0: "N"},
|
|
||||||
"logit": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {joiner_filename}")
|
|
||||||
|
|
||||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
|
||||||
torch.onnx.export(
|
|
||||||
joiner_model.encoder_proj,
|
|
||||||
encoder_out,
|
|
||||||
encoder_proj_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["encoder_out"],
|
|
||||||
output_names=["projected_encoder_out"],
|
|
||||||
dynamic_axes={
|
|
||||||
"encoder_out": {0: "N"},
|
|
||||||
"projected_encoder_out": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {encoder_proj_filename}")
|
|
||||||
|
|
||||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
|
||||||
torch.onnx.export(
|
|
||||||
joiner_model.decoder_proj,
|
|
||||||
decoder_out,
|
|
||||||
decoder_proj_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["decoder_out"],
|
|
||||||
output_names=["projected_decoder_out"],
|
|
||||||
dynamic_axes={
|
|
||||||
"decoder_out": {0: "N"},
|
|
||||||
"projected_decoder_out": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {decoder_proj_filename}")
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
@ -531,31 +292,7 @@ def main():
|
|||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.onnx is True:
|
if params.jit is True:
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
|
||||||
opset_version = 13
|
|
||||||
logging.info("Exporting to onnx format")
|
|
||||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
|
||||||
export_encoder_model_onnx(
|
|
||||||
model.encoder,
|
|
||||||
encoder_filename,
|
|
||||||
opset_version=opset_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
|
||||||
export_decoder_model_onnx(
|
|
||||||
model.decoder,
|
|
||||||
decoder_filename,
|
|
||||||
opset_version=opset_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
|
||||||
export_joiner_model_onnx(
|
|
||||||
model.joiner,
|
|
||||||
joiner_filename,
|
|
||||||
opset_version=opset_version,
|
|
||||||
)
|
|
||||||
elif params.jit is True:
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
# it here.
|
# it here.
|
||||||
|
@ -1,286 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
#
|
|
||||||
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
This script checks that exported onnx models produce the same output
|
|
||||||
with the given torchscript model for the same input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import onnxruntime as ort
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from icefall import is_module_available
|
|
||||||
|
|
||||||
if not is_module_available("onnxruntime"):
|
|
||||||
raise ValueError("Please 'pip install onnxruntime' first.")
|
|
||||||
|
|
||||||
|
|
||||||
ort.set_default_logger_severity(3)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--jit-filename",
|
|
||||||
required=True,
|
|
||||||
type=str,
|
|
||||||
help="Path to the torchscript model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--onnx-encoder-filename",
|
|
||||||
required=True,
|
|
||||||
type=str,
|
|
||||||
help="Path to the onnx encoder model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--onnx-decoder-filename",
|
|
||||||
required=True,
|
|
||||||
type=str,
|
|
||||||
help="Path to the onnx decoder model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--onnx-joiner-filename",
|
|
||||||
required=True,
|
|
||||||
type=str,
|
|
||||||
help="Path to the onnx joiner model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--onnx-joiner-encoder-proj-filename",
|
|
||||||
required=True,
|
|
||||||
type=str,
|
|
||||||
help="Path to the onnx joiner encoder projection model",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--onnx-joiner-decoder-proj-filename",
|
|
||||||
required=True,
|
|
||||||
type=str,
|
|
||||||
help="Path to the onnx joiner decoder projection model",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def test_encoder(
|
|
||||||
model: torch.jit.ScriptModule,
|
|
||||||
encoder_session: ort.InferenceSession,
|
|
||||||
):
|
|
||||||
inputs = encoder_session.get_inputs()
|
|
||||||
outputs = encoder_session.get_outputs()
|
|
||||||
input_names = [n.name for n in inputs]
|
|
||||||
output_names = [n.name for n in outputs]
|
|
||||||
|
|
||||||
assert inputs[0].shape == ["N", "T", 80]
|
|
||||||
assert inputs[1].shape == ["N"]
|
|
||||||
|
|
||||||
for N in [1, 5]:
|
|
||||||
for T in [12, 50]:
|
|
||||||
print("N, T", N, T)
|
|
||||||
x = torch.rand(N, T, 80, dtype=torch.float32)
|
|
||||||
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
|
|
||||||
x_lens[0] = T
|
|
||||||
|
|
||||||
encoder_inputs = {
|
|
||||||
input_names[0]: x.numpy(),
|
|
||||||
input_names[1]: x_lens.numpy(),
|
|
||||||
}
|
|
||||||
|
|
||||||
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
|
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = encoder_session.run(
|
|
||||||
output_names,
|
|
||||||
encoder_inputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
|
|
||||||
|
|
||||||
encoder_out = torch.from_numpy(encoder_out)
|
|
||||||
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
|
|
||||||
(encoder_out - torch_encoder_out).abs().max(),
|
|
||||||
encoder_out.shape,
|
|
||||||
torch_encoder_out.shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_decoder(
|
|
||||||
model: torch.jit.ScriptModule,
|
|
||||||
decoder_session: ort.InferenceSession,
|
|
||||||
):
|
|
||||||
inputs = decoder_session.get_inputs()
|
|
||||||
outputs = decoder_session.get_outputs()
|
|
||||||
input_names = [n.name for n in inputs]
|
|
||||||
output_names = [n.name for n in outputs]
|
|
||||||
|
|
||||||
assert inputs[0].shape == ["N", 2]
|
|
||||||
for N in [1, 5, 10]:
|
|
||||||
y = torch.randint(low=1, high=500, size=(10, 2))
|
|
||||||
|
|
||||||
decoder_inputs = {input_names[0]: y.numpy()}
|
|
||||||
decoder_out = decoder_session.run(
|
|
||||||
output_names,
|
|
||||||
decoder_inputs,
|
|
||||||
)[0]
|
|
||||||
decoder_out = torch.from_numpy(decoder_out)
|
|
||||||
|
|
||||||
torch_decoder_out = model.decoder(y, need_pad=False)
|
|
||||||
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
|
|
||||||
(decoder_out - torch_decoder_out).abs().max()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_joiner(
|
|
||||||
model: torch.jit.ScriptModule,
|
|
||||||
joiner_session: ort.InferenceSession,
|
|
||||||
joiner_encoder_proj_session: ort.InferenceSession,
|
|
||||||
joiner_decoder_proj_session: ort.InferenceSession,
|
|
||||||
):
|
|
||||||
joiner_inputs = joiner_session.get_inputs()
|
|
||||||
joiner_outputs = joiner_session.get_outputs()
|
|
||||||
joiner_input_names = [n.name for n in joiner_inputs]
|
|
||||||
joiner_output_names = [n.name for n in joiner_outputs]
|
|
||||||
|
|
||||||
assert joiner_inputs[0].shape == ["N", 1, 1, 512]
|
|
||||||
assert joiner_inputs[1].shape == ["N", 1, 1, 512]
|
|
||||||
|
|
||||||
joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
|
|
||||||
encoder_proj_input_name = joiner_encoder_proj_inputs[0].name
|
|
||||||
|
|
||||||
assert joiner_encoder_proj_inputs[0].shape == ["N", 384]
|
|
||||||
|
|
||||||
joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs()
|
|
||||||
encoder_proj_output_name = joiner_encoder_proj_outputs[0].name
|
|
||||||
|
|
||||||
joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
|
|
||||||
decoder_proj_input_name = joiner_decoder_proj_inputs[0].name
|
|
||||||
|
|
||||||
assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
|
|
||||||
|
|
||||||
joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs()
|
|
||||||
decoder_proj_output_name = joiner_decoder_proj_outputs[0].name
|
|
||||||
|
|
||||||
for N in [1, 5, 10]:
|
|
||||||
encoder_out = torch.rand(N, 384)
|
|
||||||
decoder_out = torch.rand(N, 512)
|
|
||||||
|
|
||||||
projected_encoder_out = torch.rand(N, 1, 1, 512)
|
|
||||||
projected_decoder_out = torch.rand(N, 1, 1, 512)
|
|
||||||
|
|
||||||
joiner_inputs = {
|
|
||||||
joiner_input_names[0]: projected_encoder_out.numpy(),
|
|
||||||
joiner_input_names[1]: projected_decoder_out.numpy(),
|
|
||||||
}
|
|
||||||
joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0]
|
|
||||||
joiner_out = torch.from_numpy(joiner_out)
|
|
||||||
|
|
||||||
torch_joiner_out = model.joiner(
|
|
||||||
projected_encoder_out,
|
|
||||||
projected_decoder_out,
|
|
||||||
project_input=False,
|
|
||||||
)
|
|
||||||
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
|
|
||||||
(joiner_out - torch_joiner_out).abs().max()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now test encoder_proj
|
|
||||||
joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
|
|
||||||
joiner_encoder_proj_out = joiner_encoder_proj_session.run(
|
|
||||||
[encoder_proj_output_name], joiner_encoder_proj_inputs
|
|
||||||
)[0]
|
|
||||||
joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
|
|
||||||
|
|
||||||
torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
|
|
||||||
assert torch.allclose(
|
|
||||||
joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
|
|
||||||
), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
|
|
||||||
|
|
||||||
# Now test decoder_proj
|
|
||||||
joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
|
|
||||||
joiner_decoder_proj_out = joiner_decoder_proj_session.run(
|
|
||||||
[decoder_proj_output_name], joiner_decoder_proj_inputs
|
|
||||||
)[0]
|
|
||||||
joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
|
|
||||||
|
|
||||||
torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
|
|
||||||
assert torch.allclose(
|
|
||||||
joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
|
|
||||||
), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def main():
|
|
||||||
args = get_parser().parse_args()
|
|
||||||
logging.info(vars(args))
|
|
||||||
|
|
||||||
model = torch.jit.load(args.jit_filename)
|
|
||||||
|
|
||||||
options = ort.SessionOptions()
|
|
||||||
options.inter_op_num_threads = 1
|
|
||||||
options.intra_op_num_threads = 1
|
|
||||||
|
|
||||||
logging.info("Test encoder")
|
|
||||||
encoder_session = ort.InferenceSession(
|
|
||||||
args.onnx_encoder_filename,
|
|
||||||
sess_options=options,
|
|
||||||
)
|
|
||||||
test_encoder(model, encoder_session)
|
|
||||||
|
|
||||||
logging.info("Test decoder")
|
|
||||||
decoder_session = ort.InferenceSession(
|
|
||||||
args.onnx_decoder_filename,
|
|
||||||
sess_options=options,
|
|
||||||
)
|
|
||||||
test_decoder(model, decoder_session)
|
|
||||||
|
|
||||||
logging.info("Test joiner")
|
|
||||||
joiner_session = ort.InferenceSession(
|
|
||||||
args.onnx_joiner_filename,
|
|
||||||
sess_options=options,
|
|
||||||
)
|
|
||||||
joiner_encoder_proj_session = ort.InferenceSession(
|
|
||||||
args.onnx_joiner_encoder_proj_filename,
|
|
||||||
sess_options=options,
|
|
||||||
)
|
|
||||||
joiner_decoder_proj_session = ort.InferenceSession(
|
|
||||||
args.onnx_joiner_decoder_proj_filename,
|
|
||||||
sess_options=options,
|
|
||||||
)
|
|
||||||
test_joiner(
|
|
||||||
model,
|
|
||||||
joiner_session,
|
|
||||||
joiner_encoder_proj_session,
|
|
||||||
joiner_decoder_proj_session,
|
|
||||||
)
|
|
||||||
logging.info("Finished checking ONNX models")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
torch.manual_seed(20220727)
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
main()
|
|
1
egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py
Symbolic link
1
egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless5/onnx_check.py
|
@ -1,388 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
This script loads ONNX models and uses them to decode waves.
|
|
||||||
You can use the following command to get the exported models:
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 20 \
|
|
||||||
--avg 10 \
|
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
Usage of this script:
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/onnx_pretrained.py \
|
|
||||||
--encoder-model-filename ./pruned_transducer_stateless7/exp/encoder.onnx \
|
|
||||||
--decoder-model-filename ./pruned_transducer_stateless7/exp/decoder.onnx \
|
|
||||||
--joiner-model-filename ./pruned_transducer_stateless7/exp/joiner.onnx \
|
|
||||||
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_encoder_proj.onnx \
|
|
||||||
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_decoder_proj.onnx \
|
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import kaldifeat
|
|
||||||
import numpy as np
|
|
||||||
import onnxruntime as ort
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--encoder-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the encoder onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decoder-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the decoder onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--joiner-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the joiner onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--joiner-encoder-proj-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the joiner encoder_proj onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--joiner-decoder-proj-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the joiner decoder_proj onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--bpe-model",
|
|
||||||
type=str,
|
|
||||||
help="""Path to bpe.model.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"sound_files",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
help="The input sound file(s) to transcribe. "
|
|
||||||
"Supported formats are those supported by torchaudio.load(). "
|
|
||||||
"For example, wav and flac are supported. "
|
|
||||||
"The sample rate has to be 16kHz.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--sample-rate",
|
|
||||||
type=int,
|
|
||||||
default=16000,
|
|
||||||
help="The sample rate of the input sound file",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--context-size",
|
|
||||||
type=int,
|
|
||||||
default=2,
|
|
||||||
help="Context size of the decoder model",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def read_sound_files(
|
|
||||||
filenames: List[str], expected_sample_rate: float
|
|
||||||
) -> List[torch.Tensor]:
|
|
||||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
|
||||||
Args:
|
|
||||||
filenames:
|
|
||||||
A list of sound filenames.
|
|
||||||
expected_sample_rate:
|
|
||||||
The expected sample rate of the sound files.
|
|
||||||
Returns:
|
|
||||||
Return a list of 1-D float32 torch tensors.
|
|
||||||
"""
|
|
||||||
ans = []
|
|
||||||
for f in filenames:
|
|
||||||
wave, sample_rate = torchaudio.load(f)
|
|
||||||
assert (
|
|
||||||
sample_rate == expected_sample_rate
|
|
||||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
|
||||||
# We use only the first channel
|
|
||||||
ans.append(wave[0])
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
|
||||||
decoder: ort.InferenceSession,
|
|
||||||
joiner: ort.InferenceSession,
|
|
||||||
joiner_encoder_proj: ort.InferenceSession,
|
|
||||||
joiner_decoder_proj: ort.InferenceSession,
|
|
||||||
encoder_out: np.ndarray,
|
|
||||||
encoder_out_lens: np.ndarray,
|
|
||||||
context_size: int,
|
|
||||||
) -> List[List[int]]:
|
|
||||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
|
||||||
Args:
|
|
||||||
decoder:
|
|
||||||
The decoder model.
|
|
||||||
joiner:
|
|
||||||
The joiner model.
|
|
||||||
joiner_encoder_proj:
|
|
||||||
The joiner encoder projection model.
|
|
||||||
joiner_decoder_proj:
|
|
||||||
The joiner decoder projection model.
|
|
||||||
encoder_out:
|
|
||||||
A 3-D tensor of shape (N, T, C)
|
|
||||||
encoder_out_lens:
|
|
||||||
A 1-D tensor of shape (N,).
|
|
||||||
context_size:
|
|
||||||
The context size of the decoder model.
|
|
||||||
Returns:
|
|
||||||
Return the decoded results for each utterance.
|
|
||||||
"""
|
|
||||||
encoder_out = torch.from_numpy(encoder_out)
|
|
||||||
encoder_out_lens = torch.from_numpy(encoder_out_lens)
|
|
||||||
assert encoder_out.ndim == 3
|
|
||||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
|
||||||
|
|
||||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
|
||||||
input=encoder_out,
|
|
||||||
lengths=encoder_out_lens.cpu(),
|
|
||||||
batch_first=True,
|
|
||||||
enforce_sorted=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
projected_encoder_out = joiner_encoder_proj.run(
|
|
||||||
[joiner_encoder_proj.get_outputs()[0].name],
|
|
||||||
{joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
blank_id = 0 # hard-code to 0
|
|
||||||
|
|
||||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
|
||||||
N = encoder_out.size(0)
|
|
||||||
|
|
||||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
|
||||||
assert N == batch_size_list[0], (N, batch_size_list)
|
|
||||||
|
|
||||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
|
||||||
|
|
||||||
decoder_input_nodes = decoder.get_inputs()
|
|
||||||
decoder_output_nodes = decoder.get_outputs()
|
|
||||||
|
|
||||||
joiner_input_nodes = joiner.get_inputs()
|
|
||||||
joiner_output_nodes = joiner.get_outputs()
|
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
|
||||||
hyps,
|
|
||||||
dtype=torch.int64,
|
|
||||||
) # (N, context_size)
|
|
||||||
|
|
||||||
decoder_out = decoder.run(
|
|
||||||
[decoder_output_nodes[0].name],
|
|
||||||
{
|
|
||||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
|
||||||
},
|
|
||||||
)[0].squeeze(1)
|
|
||||||
projected_decoder_out = joiner_decoder_proj.run(
|
|
||||||
[joiner_decoder_proj.get_outputs()[0].name],
|
|
||||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
|
||||||
|
|
||||||
offset = 0
|
|
||||||
for batch_size in batch_size_list:
|
|
||||||
start = offset
|
|
||||||
end = offset + batch_size
|
|
||||||
current_encoder_out = projected_encoder_out[start:end]
|
|
||||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
|
||||||
offset = end
|
|
||||||
|
|
||||||
projected_decoder_out = projected_decoder_out[:batch_size]
|
|
||||||
|
|
||||||
logits = joiner.run(
|
|
||||||
[joiner_output_nodes[0].name],
|
|
||||||
{
|
|
||||||
joiner_input_nodes[0].name: np.expand_dims(
|
|
||||||
np.expand_dims(current_encoder_out, axis=1), axis=1
|
|
||||||
),
|
|
||||||
joiner_input_nodes[1]
|
|
||||||
.name: projected_decoder_out.unsqueeze(1)
|
|
||||||
.unsqueeze(1)
|
|
||||||
.numpy(),
|
|
||||||
},
|
|
||||||
)[0]
|
|
||||||
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
|
|
||||||
# logits'shape (batch_size, vocab_size)
|
|
||||||
|
|
||||||
assert logits.ndim == 2, logits.shape
|
|
||||||
y = logits.argmax(dim=1).tolist()
|
|
||||||
emitted = False
|
|
||||||
for i, v in enumerate(y):
|
|
||||||
if v != blank_id:
|
|
||||||
hyps[i].append(v)
|
|
||||||
emitted = True
|
|
||||||
if emitted:
|
|
||||||
# update decoder output
|
|
||||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
|
||||||
decoder_input = torch.tensor(
|
|
||||||
decoder_input,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
decoder_out = decoder.run(
|
|
||||||
[decoder_output_nodes[0].name],
|
|
||||||
{
|
|
||||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
|
||||||
},
|
|
||||||
)[0].squeeze(1)
|
|
||||||
projected_decoder_out = joiner_decoder_proj.run(
|
|
||||||
[joiner_decoder_proj.get_outputs()[0].name],
|
|
||||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
|
||||||
)[0]
|
|
||||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
|
||||||
|
|
||||||
sorted_ans = [h[context_size:] for h in hyps]
|
|
||||||
ans = []
|
|
||||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
|
||||||
for i in range(N):
|
|
||||||
ans.append(sorted_ans[unsorted_indices[i]])
|
|
||||||
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def main():
|
|
||||||
parser = get_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
logging.info(vars(args))
|
|
||||||
|
|
||||||
session_opts = ort.SessionOptions()
|
|
||||||
session_opts.inter_op_num_threads = 1
|
|
||||||
session_opts.intra_op_num_threads = 1
|
|
||||||
|
|
||||||
encoder = ort.InferenceSession(
|
|
||||||
args.encoder_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder = ort.InferenceSession(
|
|
||||||
args.decoder_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
joiner = ort.InferenceSession(
|
|
||||||
args.joiner_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
joiner_encoder_proj = ort.InferenceSession(
|
|
||||||
args.joiner_encoder_proj_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
joiner_decoder_proj = ort.InferenceSession(
|
|
||||||
args.joiner_decoder_proj_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
|
||||||
sp.load(args.bpe_model)
|
|
||||||
|
|
||||||
logging.info("Constructing Fbank computer")
|
|
||||||
opts = kaldifeat.FbankOptions()
|
|
||||||
opts.device = "cpu"
|
|
||||||
opts.frame_opts.dither = 0
|
|
||||||
opts.frame_opts.snip_edges = False
|
|
||||||
opts.frame_opts.samp_freq = args.sample_rate
|
|
||||||
opts.mel_opts.num_bins = 80
|
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
|
||||||
|
|
||||||
logging.info(f"Reading sound files: {args.sound_files}")
|
|
||||||
waves = read_sound_files(
|
|
||||||
filenames=args.sound_files,
|
|
||||||
expected_sample_rate=args.sample_rate,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Decoding started")
|
|
||||||
features = fbank(waves)
|
|
||||||
feature_lengths = [f.size(0) for f in features]
|
|
||||||
|
|
||||||
features = pad_sequence(
|
|
||||||
features,
|
|
||||||
batch_first=True,
|
|
||||||
padding_value=math.log(1e-10),
|
|
||||||
)
|
|
||||||
|
|
||||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
|
||||||
|
|
||||||
encoder_input_nodes = encoder.get_inputs()
|
|
||||||
encoder_out_nodes = encoder.get_outputs()
|
|
||||||
encoder_out, encoder_out_lens = encoder.run(
|
|
||||||
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
|
|
||||||
{
|
|
||||||
encoder_input_nodes[0].name: features.numpy(),
|
|
||||||
encoder_input_nodes[1].name: feature_lengths.numpy(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
hyps = greedy_search(
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
joiner_encoder_proj=joiner_encoder_proj,
|
|
||||||
joiner_decoder_proj=joiner_decoder_proj,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
context_size=args.context_size,
|
|
||||||
)
|
|
||||||
s = "\n"
|
|
||||||
for filename, hyp in zip(args.sound_files, hyps):
|
|
||||||
words = sp.decode(hyp)
|
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
main()
|
|
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless3/onnx_pretrained.py
|
@ -44,7 +44,7 @@ from scaling import (
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from icefall.dist import get_rank
|
from icefall.dist import get_rank
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import is_jit_tracing, make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
class Zipformer(EncoderInterface):
|
class Zipformer(EncoderInterface):
|
||||||
@ -792,6 +792,7 @@ class AttentionDownsample(torch.nn.Module):
|
|||||||
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
||||||
scores = (src * self.query).sum(dim=-1, keepdim=True)
|
scores = (src * self.query).sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||||
scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04)
|
scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04)
|
||||||
|
|
||||||
weights = scores.softmax(dim=1)
|
weights = scores.softmax(dim=1)
|
||||||
@ -904,6 +905,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
|
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
|
||||||
"""Construct a PositionalEncoding object."""
|
"""Construct a PositionalEncoding object."""
|
||||||
super(RelPositionalEncoding, self).__init__()
|
super(RelPositionalEncoding, self).__init__()
|
||||||
|
if is_jit_tracing():
|
||||||
|
# 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e.,
|
||||||
|
# It assumes that the maximum input won't have more than
|
||||||
|
# 10k frames.
|
||||||
|
#
|
||||||
|
# TODO(fangjun): Use torch.jit.script() for this module
|
||||||
|
max_len = 10000
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||||
self.pe = None
|
self.pe = None
|
||||||
|
@ -306,11 +306,11 @@ def export_encoder_model_onnx(
|
|||||||
left_context_len = ",".join(map(str, left_context_len))
|
left_context_len = ",".join(map(str, left_context_len))
|
||||||
|
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"model_type": "streaming_zipformer",
|
"model_type": "zipformer",
|
||||||
"version": "1",
|
"version": "1",
|
||||||
"model_author": "k2-fsa",
|
"model_author": "k2-fsa",
|
||||||
"decode_chunk_len": str(decode_chunk_len), # 32
|
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||||
"pad_length": str(pad_length), # 7
|
"T": str(T), # 39
|
||||||
"num_encoder_layers": num_encoder_layers,
|
"num_encoder_layers": num_encoder_layers,
|
||||||
"encoder_dims": encoder_dims,
|
"encoder_dims": encoder_dims,
|
||||||
"attention_dims": attention_dims,
|
"attention_dims": attention_dims,
|
||||||
@ -362,8 +362,8 @@ def export_encoder_model_onnx(
|
|||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"x": {0: "N", 1: "T"},
|
"x": {0: "N"},
|
||||||
"encoder_out": {0: "N", 1: "T"},
|
"encoder_out": {0: "N"},
|
||||||
**inputs,
|
**inputs,
|
||||||
**outputs,
|
**outputs,
|
||||||
},
|
},
|
||||||
|
@ -136,8 +136,11 @@ class OnnxModel:
|
|||||||
def init_encoder_states(self, batch_size: int = 1):
|
def init_encoder_states(self, batch_size: int = 1):
|
||||||
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||||
|
|
||||||
|
model_type = encoder_meta["model_type"]
|
||||||
|
assert model_type == "zipformer", model_type
|
||||||
|
|
||||||
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||||
pad_length = int(encoder_meta["pad_length"])
|
T = int(encoder_meta["T"])
|
||||||
|
|
||||||
num_encoder_layers = encoder_meta["num_encoder_layers"]
|
num_encoder_layers = encoder_meta["num_encoder_layers"]
|
||||||
encoder_dims = encoder_meta["encoder_dims"]
|
encoder_dims = encoder_meta["encoder_dims"]
|
||||||
@ -155,7 +158,7 @@ class OnnxModel:
|
|||||||
left_context_len = to_int_list(left_context_len)
|
left_context_len = to_int_list(left_context_len)
|
||||||
|
|
||||||
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
logging.info(f"decode_chunk_len: {decode_chunk_len}")
|
||||||
logging.info(f"pad_length: {pad_length}")
|
logging.info(f"T: {T}")
|
||||||
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
logging.info(f"num_encoder_layers: {num_encoder_layers}")
|
||||||
logging.info(f"encoder_dims: {encoder_dims}")
|
logging.info(f"encoder_dims: {encoder_dims}")
|
||||||
logging.info(f"attention_dims: {attention_dims}")
|
logging.info(f"attention_dims: {attention_dims}")
|
||||||
@ -219,7 +222,7 @@ class OnnxModel:
|
|||||||
|
|
||||||
self.num_encoders = num_encoders
|
self.num_encoders = num_encoders
|
||||||
|
|
||||||
self.segment = decode_chunk_len + pad_length
|
self.segment = T
|
||||||
self.offset = decode_chunk_len
|
self.offset = decode_chunk_len
|
||||||
|
|
||||||
def init_decoder(self, decoder_model_filename: str):
|
def init_decoder(self, decoder_model_filename: str):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user