Add scripts to export streaming zipformer(v1) to RKNN (#1882)

Fangjun Kuang 2025-02-27 17:10:58 +08:00 committed by GitHub
parent 2ba665abca
commit db9fb8ad31
7 changed files with 1155 additions and 31 deletions

.github/scripts/docker/generate_build_matrix.py

@@ -10,7 +10,17 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--min-torch-version",
help="Minimu torch version",
help="torch version",
)
parser.add_argument(
"--torch-version",
help="torch version",
)
parser.add_argument(
"--python-version",
help="python version",
)
return parser.parse_args()
@@ -52,7 +62,7 @@ def get_torchaudio_version(torch_version):
return torch_version
def get_matrix(min_torch_version):
def get_matrix(min_torch_version, specified_torch_version, specified_python_version):
k2_version = "1.24.4.dev20241029"
kaldifeat_version = "1.25.5.dev20241029"
version = "20241218"
@@ -71,6 +81,12 @@ def get_matrix(min_torch_version):
torch_version += ["2.5.0"]
torch_version += ["2.5.1"]
if specified_torch_version:
torch_version = [specified_torch_version]
if specified_python_version:
python_version = [specified_python_version]
matrix = []
for p in python_version:
for t in torch_version:
@@ -115,7 +131,11 @@ def get_matrix(min_torch_version):
def main():
args = get_args()
matrix = get_matrix(min_torch_version=args.min_torch_version)
matrix = get_matrix(
min_torch_version=args.min_torch_version,
specified_torch_version=args.torch_version,
specified_python_version=args.python_version,
)
print(json.dumps({"include": matrix}))
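The two new flags collapse the build matrix to a single entry, which is what the rknn workflow below relies on. A condensed sketch of the override logic (the helper name is illustrative, not part of this commit):

def select_versions(torch_versions, python_versions,
                    specified_torch_version=None, specified_python_version=None):
    # A specified version collapses that axis of the matrix to one entry.
    if specified_torch_version:
        torch_versions = [specified_torch_version]
    if specified_python_version:
        python_versions = [specified_python_version]
    return [(p, t) for p in python_versions for t in torch_versions]

# rknn.yml passes --torch-version=2.4.0 --python-version=3.10:
print(select_versions(["2.4.0", "2.5.0", "2.5.1"], ["3.8", "3.10"], "2.4.0", "3.10"))
# -> [('3.10', '2.4.0')], i.e. a single matrix entry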

.github/scripts/librispeech/ASR/run_rknn.sh (new executable file)

@@ -0,0 +1,200 @@
#!/usr/bin/env bash
set -ex
python3 -m pip install kaldi-native-fbank soundfile librosa
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
# https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed
# sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
function export_bilingual_zh_en() {
d=exp_zh_en
mkdir $d
pushd $d
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/exp/pretrained.pt
mv pretrained.pt epoch-99.pt
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/data/lang_char_bpe/tokens.txt
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/0.wav
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/1.wav
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/2.wav
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/3.wav
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/4.wav
ls -lh
popd
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
--dynamic-batch 0 \
--enable-int8-quantization 0 \
--tokens $d/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $d/ \
--decode-chunk-len 64 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,1536,1536,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
ls -lh $d/
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
--encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
--tokens $d/tokens.txt \
$d/0.wav
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
--encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
--tokens $d/tokens.txt \
$d/1.wav
mkdir -p /icefall/rknn-models
for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
mkdir -p $platform
./pruned_transducer_stateless7_streaming/export_rknn.py \
--in-encoder $d/encoder-epoch-99-avg-1.onnx \
--in-decoder $d/decoder-epoch-99-avg-1.onnx \
--in-joiner $d/joiner-epoch-99-avg-1.onnx \
--out-encoder $platform/encoder.rknn \
--out-decoder $platform/decoder.rknn \
--out-joiner $platform/joiner.rknn \
--target-platform $platform 2>/dev/null
ls -lh $platform/
./pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py \
--encoder $d/encoder-epoch-99-avg-1.onnx \
--decoder $d/decoder-epoch-99-avg-1.onnx \
--joiner $d/joiner-epoch-99-avg-1.onnx \
--tokens $d/tokens.txt \
--wav $d/0.wav
cp $d/tokens.txt $platform
cp $d/*.wav $platform
cp -av $platform /icefall/rknn-models
done
ls -lh /icefall/rknn-models
}
# https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t
# sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16
function export_bilingual_zh_en_small() {
d=exp_zh_en_small
mkdir $d
pushd $d
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/exp/pretrained.pt
mv pretrained.pt epoch-99.pt
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/data/lang_char_bpe/tokens.txt
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/0.wav
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/1.wav
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/2.wav
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/3.wav
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/4.wav
ls -lh
popd
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
--dynamic-batch 0 \
--enable-int8-quantization 0 \
--tokens $d/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $d/ \
--decode-chunk-len 64 \
\
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dims 768,768,768,768,768 \
--nhead 4,4,4,4,4 \
--encoder-dims 256,256,256,256,256 \
--attention-dims 192,192,192,192,192 \
--encoder-unmasked-dims 192,192,192,192,192 \
\
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
ls -lh $d/
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
--encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
--tokens $d/tokens.txt \
$d/0.wav
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
--encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
--tokens $d/tokens.txt \
$d/1.wav
mkdir -p /icefall/rknn-models-small
for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
mkdir -p $platform
./pruned_transducer_stateless7_streaming/export_rknn.py \
--in-encoder $d/encoder-epoch-99-avg-1.onnx \
--in-decoder $d/decoder-epoch-99-avg-1.onnx \
--in-joiner $d/joiner-epoch-99-avg-1.onnx \
--out-encoder $platform/encoder.rknn \
--out-decoder $platform/decoder.rknn \
--out-joiner $platform/joiner.rknn \
--target-platform $platform 2>/dev/null
ls -lh $platform/
./pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py \
--encoder $d/encoder-epoch-99-avg-1.onnx \
--decoder $d/decoder-epoch-99-avg-1.onnx \
--joiner $d/joiner-epoch-99-avg-1.onnx \
--tokens $d/tokens.txt \
--wav $d/0.wav
cp $d/tokens.txt $platform
cp $d/*.wav $platform
cp -av $platform /icefall/rknn-models-small
done
ls -lh /icefall/rknn-models-small
}
export_bilingual_zh_en_small
export_bilingual_zh_en

.github/workflows/rknn.yml (new file)

@@ -0,0 +1,180 @@
name: rknn
on:
push:
branches:
- master
- ci-rknn-2
pull_request:
branches:
- master
workflow_dispatch:
concurrency:
group: rknn-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10)
echo "::set-output name=matrix::${MATRIX}"
rknn:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
if: false
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Export ONNX model
uses: addnab/docker-run-action@v3
with:
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
run: |
cat /etc/*release
lsb_release -a
uname -a
python3 --version
export PYTHONPATH=/icefall:$PYTHONPATH
cd /icefall
git config --global --add safe.directory /icefall
python3 -m torch.utils.collect_env
python3 -m k2.version
pip list
# Install rknn
curl -SL -O https://huggingface.co/csukuangfj/rknn-toolkit2/resolve/main/rknn_toolkit2-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ./*.whl "numpy<=1.26.4"
pip list | grep rknn
echo "---"
pip list
echo "---"
.github/scripts/librispeech/ASR/run_rknn.sh
- name: Display rknn models
shell: bash
run: |
ls -lh
ls -lh rknn-models/*
echo "----"
ls -lh rknn-models-small/*
- name: Collect results (small)
shell: bash
run: |
for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
dst=sherpa-onnx-$platform-streaming-zipformer-small-bilingual-zh-en-2023-02-16
mkdir $dst
mkdir $dst/test_wavs
src=rknn-models-small/$platform
cp -v $src/*.rknn $dst/
cp -v $src/tokens.txt $dst/
cp -v $src/*.wav $dst/test_wavs/
ls -lh $dst
tar cjfv $dst.tar.bz2 $dst
rm -rf $dst
done
- name: Collect results
shell: bash
run: |
for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
dst=sherpa-onnx-$platform-streaming-zipformer-bilingual-zh-en-2023-02-20
mkdir $dst
mkdir $dst/test_wavs
src=rknn-models/$platform
cp -v $src/*.rknn $dst/
cp -v $src/tokens.txt $dst/
cp -v $src/*.wav $dst/test_wavs/
ls -lh $dst
tar cjfv $dst.tar.bz2 $dst
rm -rf $dst
done
- name: Display results
shell: bash
run: |
ls -lh *rk*.tar.bz2
- name: Release to GitHub
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
overwrite: true
file: sherpa-onnx-*.tar.bz2
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: asr-models
- name: Upload model to huggingface
if: github.event_name == 'push'
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
git clone https://huggingface.co/csukuangfj/sherpa-onnx-rknn-models huggingface
cd huggingface
git fetch
git pull
git merge -m "merge remote" --ff origin main
dst=streaming-asr
mkdir -p $dst
rm -fv $dst/*
cp ../*rk*.tar.bz2 $dst/
ls -lh $dst
git add .
git status
git commit -m "update models"
git status
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-rknn-models main || true
rm -rf huggingface

egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py

@@ -85,6 +85,20 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--dynamic-batch",
type=int,
default=1,
help="1 to support dynamic batch size. 0 to support only batch size == 1",
)
parser.add_argument(
"--enable-int8-quantization",
type=int,
default=1,
help="1 to also export int8 onnx models.",
)
parser.add_argument(
"--epoch",
type=int,
@@ -257,6 +271,7 @@ def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
) -> None:
"""
Onnx model inputs:
@@ -274,6 +289,8 @@ def export_encoder_model_onnx(
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
dynamic_batch:
True to export a model supporting dynamic batch size
"""
encoder_model.encoder.__class__.forward = (
@@ -379,7 +396,9 @@ def export_encoder_model_onnx(
"encoder_out": {0: "N"},
**inputs,
**outputs,
},
}
if dynamic_batch
else {},
)
add_meta_data(filename=encoder_filename, meta_data=meta_data)
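All three exports pass dynamic_axes this way: with --dynamic-batch 0 the dict is empty, so every dimension is baked in as a fixed size, which is what the RKNN toolchain expects. A minimal standalone sketch of the pattern, using a hypothetical tiny module rather than the zipformer encoder:

import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x * 2

dynamic_batch = False  # i.e. --dynamic-batch 0

torch.onnx.export(
    Tiny(),
    torch.rand(1, 3),
    "tiny.onnx",
    input_names=["x"],
    output_names=["y"],
    opset_version=13,
    # Empty dict: all dims static. Otherwise dim 0 stays symbolic ("N").
    dynamic_axes={"x": {0: "N"}, "y": {0: "N"}} if dynamic_batch else {},
)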
@@ -389,6 +408,7 @@ def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
) -> None:
"""Export the decoder model to ONNX format.
@@ -412,7 +432,7 @@
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
y = torch.zeros(1, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
@@ -425,7 +445,9 @@
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
}
if dynamic_batch
else {},
)
meta_data = {
"context_size": str(context_size),
@@ -438,6 +460,7 @@ def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
@@ -452,8 +475,8 @@
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
@@ -470,7 +493,9 @@
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
}
if dynamic_batch
else {},
)
meta_data = {
"joiner_dim": str(joiner_dim),
@@ -629,6 +654,7 @@ def main():
encoder,
encoder_filename,
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
)
logging.info(f"Exported encoder to {encoder_filename}")
@@ -638,6 +664,7 @@ def main():
decoder,
decoder_filename,
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
)
logging.info(f"Exported decoder to {decoder_filename}")
@@ -647,37 +674,39 @@ def main():
joiner,
joiner_filename,
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
if params.enable_int8_quantization:
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
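When int8 export is enabled, the quantized models are drop-in replacements for the fp32 ones. A hedged check (paths follow the suffix scheme above and assume the script ran with --enable-int8-quantization 1):

import os

import onnxruntime

fp32 = "exp_zh_en/encoder-epoch-99-avg-1.onnx"
int8 = "exp_zh_en/encoder-epoch-99-avg-1.int8.onnx"

# Weight-only dynamic quantization shrinks the file but keeps the interface.
print(os.path.getsize(fp32), os.path.getsize(int8))
sess = onnxruntime.InferenceSession(int8, providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])  # same inputs as the fp32 model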
if __name__ == "__main__":

egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export_rknn.py (new executable file)

@@ -0,0 +1,261 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
import argparse
import logging
from pathlib import Path
from typing import List
from rknn.api import RKNN
logging.basicConfig(level=logging.WARNING)
g_platforms = [
# "rv1103",
# "rv1103b",
# "rv1106",
# "rk2118",
"rk3562",
"rk3566",
"rk3568",
"rk3576",
"rk3588",
]
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--target-platform",
type=str,
required=True,
help=f"Supported values are: {','.join(g_platforms)}",
)
parser.add_argument(
"--in-encoder",
type=str,
required=True,
help="Path to the encoder onnx model",
)
parser.add_argument(
"--in-decoder",
type=str,
required=True,
help="Path to the decoder onnx model",
)
parser.add_argument(
"--in-joiner",
type=str,
required=True,
help="Path to the joiner onnx model",
)
parser.add_argument(
"--out-encoder",
type=str,
required=True,
help="Path to the encoder rknn model",
)
parser.add_argument(
"--out-decoder",
type=str,
required=True,
help="Path to the decoder rknn model",
)
parser.add_argument(
"--out-joiner",
type=str,
required=True,
help="Path to the joiner rknn model",
)
return parser
def export_rknn(rknn, filename):
ret = rknn.export_rknn(filename)
if ret != 0:
exit("Export rknn model to {filename} failed!")
def init_model(filename: str, target_platform: str, custom_string=None):
rknn = RKNN(verbose=False)
rknn.config(target_platform=target_platform, custom_string=custom_string)
if not Path(filename).is_file():
exit(f"{filename} does not exist")
ret = rknn.load_onnx(model=filename)
if ret != 0:
exit(f"Load model {filename} failed!")
ret = rknn.build(do_quantization=False)
if ret != 0:
exit("Build model {filename} failed!")
return rknn
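init_model wraps the usual rknn-toolkit2 sequence (config, load_onnx, build), and exporting is a separate step, so each run of this script builds for exactly one --target-platform. A hedged end-to-end use of the two helpers above (file names are examples):

rknn = init_model(
    "encoder-epoch-99-avg-1.onnx",
    target_platform="rk3588",
    custom_string="model_type=zipformer",
)
export_rknn(rknn, "encoder.rknn")
rknn.release()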
class MetaData:
def __init__(
self,
model_type: str,
attention_dims: List[int],
encoder_dims: List[int],
T: int,
left_context_len: List[int],
decode_chunk_len: int,
cnn_module_kernels: List[int],
num_encoder_layers: List[int],
context_size: int,
):
self.model_type = model_type
self.attention_dims = attention_dims
self.encoder_dims = encoder_dims
self.T = T
self.left_context_len = left_context_len
self.decode_chunk_len = decode_chunk_len
self.cnn_module_kernels = cnn_module_kernels
self.num_encoder_layers = num_encoder_layers
self.context_size = context_size
def __str__(self) -> str:
return self.to_str()
def to_str(self) -> str:
def to_s(ll):
return ",".join(list(map(str, ll)))
s = f"model_type={self.model_type}"
s += ";attention_dims=" + to_s(self.attention_dims)
s += ";encoder_dims=" + to_s(self.encoder_dims)
s += ";T=" + str(self.T)
s += ";left_context_len=" + to_s(self.left_context_len)
s += ";decode_chunk_len=" + str(self.decode_chunk_len)
s += ";cnn_module_kernels=" + to_s(self.cnn_module_kernels)
s += ";num_encoder_layers=" + to_s(self.num_encoder_layers)
s += ";context_size=" + str(self.context_size)
assert len(s) < 1024, (s, len(s))
return s
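The custom_string travels inside the .rknn file and is read back by the runtime on device. A minimal sketch of the inverse parse (the helper is illustrative, not part of this commit):

def parse_custom_string(s: str) -> dict:
    # Inverse of MetaData.to_str(): "k1=v1;k2=v2;..." -> {"k1": "v1", ...}
    return dict(kv.split("=", 1) for kv in s.split(";"))

meta = parse_custom_string("model_type=zipformer;decode_chunk_len=96;context_size=2")
assert meta["model_type"] == "zipformer"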
def get_meta_data(encoder: str, decoder: str):
import onnxruntime
session_opts = onnxruntime.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
m_encoder = onnxruntime.InferenceSession(
encoder,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
m_decoder = onnxruntime.InferenceSession(
decoder,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
encoder_meta = m_encoder.get_modelmeta().custom_metadata_map
print(encoder_meta)
# {'attention_dims': '192,192,192,192,192', 'version': '1',
# 'model_type': 'zipformer', 'encoder_dims': '256,256,256,256,256',
# 'model_author': 'k2-fsa', 'T': '103',
# 'left_context_len': '192,96,48,24,96',
# 'decode_chunk_len': '96',
# 'cnn_module_kernels': '31,31,31,31,31',
# 'num_encoder_layers': '2,2,2,2,2'}
def to_int_list(s):
return list(map(int, s.split(",")))
decoder_meta = m_decoder.get_modelmeta().custom_metadata_map
print(decoder_meta)
model_type = encoder_meta["model_type"]
attention_dims = to_int_list(encoder_meta["attention_dims"])
encoder_dims = to_int_list(encoder_meta["encoder_dims"])
T = int(encoder_meta["T"])
left_context_len = to_int_list(encoder_meta["left_context_len"])
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
cnn_module_kernels = to_int_list(encoder_meta["cnn_module_kernels"])
num_encoder_layers = to_int_list(encoder_meta["num_encoder_layers"])
context_size = int(decoder_meta["context_size"])
return MetaData(
model_type=model_type,
attention_dims=attention_dims,
encoder_dims=encoder_dims,
T=T,
left_context_len=left_context_len,
decode_chunk_len=decode_chunk_len,
cnn_module_kernels=cnn_module_kernels,
num_encoder_layers=num_encoder_layers,
context_size=context_size,
)
class RKNNModel:
def __init__(
self,
encoder: str,
decoder: str,
joiner: str,
target_platform: str,
):
self.meta = get_meta_data(encoder, decoder)
self.encoder = init_model(
encoder,
custom_string=self.meta.to_str(),
target_platform=target_platform,
)
self.decoder = init_model(decoder, target_platform=target_platform)
self.joiner = init_model(joiner, target_platform=target_platform)
def export_rknn(self, encoder, decoder, joiner):
export_rknn(self.encoder, encoder)
export_rknn(self.decoder, decoder)
export_rknn(self.joiner, joiner)
def release(self):
self.encoder.release()
self.decoder.release()
self.joiner.release()
def main():
args = get_parser().parse_args()
print(vars(args))
model = RKNNModel(
encoder=args.in_encoder,
decoder=args.in_decoder,
joiner=args.in_joiner,
target_platform=args.target_platform,
)
print(model.meta)
model.export_rknn(
encoder=args.out_encoder,
decoder=args.out_decoder,
joiner=args.out_joiner,
)
model.release()
if __name__ == "__main__":
main()

egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py

@@ -132,10 +132,18 @@ class OnnxModel:
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
print("==========Encoder input==========")
for i in self.encoder.get_inputs():
print(i)
print("==========Encoder output==========")
for i in self.encoder.get_outputs():
print(i)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
print(encoder_meta)
model_type = encoder_meta["model_type"]
assert model_type == "zipformer", model_type
@@ -232,6 +240,12 @@ class OnnxModel:
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
print("==========Decoder input==========")
for i in self.decoder.get_inputs():
print(i)
print("==========Decoder output==========")
for i in self.decoder.get_outputs():
print(i)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
@@ -247,6 +261,13 @@ class OnnxModel:
providers=["CPUExecutionProvider"],
)
print("==========Joiner input==========")
for i in self.joiner.get_inputs():
print(i)
print("==========Joiner output==========")
for i in self.joiner.get_outputs():
print(i)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
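With --dynamic-batch 0 the printed NodeArg entries should show concrete integers rather than a symbolic "N". A hedged programmatic check (the model path is an example):

import onnxruntime

sess = onnxruntime.InferenceSession(
    "exp_zh_en/encoder-epoch-99-avg-1.onnx",
    providers=["CPUExecutionProvider"],
)
for node in sess.get_inputs():
    # Symbolic dims appear as strings (e.g. "N"); static dims as ints.
    assert all(isinstance(d, int) for d in node.shape), (node.name, node.shape)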

egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py (new executable file)

@@ -0,0 +1,413 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
import argparse
from pathlib import Path
from typing import List, Tuple
import kaldi_native_fbank as knf
import numpy as np
import soundfile as sf
from rknn.api import RKNN
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder",
type=str,
required=True,
help="Path to the encoder onnx model",
)
parser.add_argument(
"--decoder",
type=str,
required=True,
help="Path to the decoder onnx model",
)
parser.add_argument(
"--joiner",
type=str,
required=True,
help="Path to the joiner onnx model",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to the tokens.txt",
)
parser.add_argument(
"--wav",
type=str,
required=True,
help="Path to test wave",
)
return parser
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
data, sample_rate = sf.read(
filename,
always_2d=True,
dtype="float32",
)
data = data[:, 0] # use only the first channel
samples = np.ascontiguousarray(data)
return samples, sample_rate
def compute_features(filename: str, dim: int = 80) -> np.ndarray:
"""
Args:
filename:
Path to an audio file.
Returns:
Return a 2-D float32 array of shape (num_frames, dim) containing the features.
"""
wave, sample_rate = load_audio(filename)
if sample_rate != 16000:
import librosa
wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
sample_rate = 16000
features = []
opts = knf.FbankOptions()
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = dim
opts.frame_opts.snip_edges = False
fbank = knf.OnlineFbank(opts)
fbank.accept_waveform(16000, wave)
tail_paddings = np.zeros(int(0.5 * 16000), dtype=np.float32)
fbank.accept_waveform(16000, tail_paddings)
fbank.input_finished()
for i in range(fbank.num_frames_ready):
f = fbank.get_frame(i)
features.append(f)
features = np.stack(features, axis=0)
return features
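A hedged usage example (the wave path is illustrative); with a 10 ms frame shift and snip_edges=False, expect roughly 100 frames per second of audio plus the 0.5 s tail padding:

features = compute_features("exp_zh_en/0.wav")
print(features.shape, features.dtype)  # (num_frames, 80) float32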
def load_tokens(filename):
tokens = dict()
with open(filename, "r") as f:
for line in f:
t, i = line.split()
tokens[int(i)] = t
return tokens
def init_model(filename, target_platform="rk3588", custom_string=None):
rknn = RKNN(verbose=False)
rknn.config(target_platform=target_platform, custom_string=custom_string)
if not Path(filename).is_file():
exit(f"{filename} does not exist")
ret = rknn.load_onnx(model=filename)
if ret != 0:
exit(f"Load model {filename} failed!")
ret = rknn.build(do_quantization=False)
if ret != 0:
exit("Build model {filename} failed!")
ret = rknn.init_runtime()
if ret != 0:
exit(f"Failed to init rknn runtime for {filename}")
return rknn
class MetaData:
def __init__(
self,
model_type: str,
attention_dims: List[int],
encoder_dims: List[int],
T: int,
left_context_len: List[int],
decode_chunk_len: int,
cnn_module_kernels: List[int],
num_encoder_layers: List[int],
):
self.model_type = model_type
self.attention_dims = attention_dims
self.encoder_dims = encoder_dims
self.T = T
self.left_context_len = left_context_len
self.decode_chunk_len = decode_chunk_len
self.cnn_module_kernels = cnn_module_kernels
self.num_encoder_layers = num_encoder_layers
def __str__(self) -> str:
return self.to_str()
def to_str(self) -> str:
def to_s(ll):
return ",".join(list(map(str, ll)))
s = f"model_type={self.model_type}"
s += ";attention_dims=" + to_s(self.attention_dims)
s += ";encoder_dims=" + to_s(self.encoder_dims)
s += ";T=" + str(self.T)
s += ";left_context_len=" + to_s(self.left_context_len)
s += ";decode_chunk_len=" + str(self.decode_chunk_len)
s += ";cnn_module_kernels=" + to_s(self.cnn_module_kernels)
s += ";num_encoder_layers=" + to_s(self.num_encoder_layers)
assert len(s) < 1024, (s, len(s))
return s
def get_meta_data(encoder: str):
import onnxruntime
session_opts = onnxruntime.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
m = onnxruntime.InferenceSession(
encoder,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
meta = m.get_modelmeta().custom_metadata_map
print(meta)
# {'attention_dims': '192,192,192,192,192', 'version': '1',
# 'model_type': 'zipformer', 'encoder_dims': '256,256,256,256,256',
# 'model_author': 'k2-fsa', 'T': '103',
# 'left_context_len': '192,96,48,24,96',
# 'decode_chunk_len': '96',
# 'cnn_module_kernels': '31,31,31,31,31',
# 'num_encoder_layers': '2,2,2,2,2'}
def to_int_list(s):
return list(map(int, s.split(",")))
model_type = meta["model_type"]
attention_dims = to_int_list(meta["attention_dims"])
encoder_dims = to_int_list(meta["encoder_dims"])
T = int(meta["T"])
left_context_len = to_int_list(meta["left_context_len"])
decode_chunk_len = int(meta["decode_chunk_len"])
cnn_module_kernels = to_int_list(meta["cnn_module_kernels"])
num_encoder_layers = to_int_list(meta["num_encoder_layers"])
return MetaData(
model_type=model_type,
attention_dims=attention_dims,
encoder_dims=encoder_dims,
T=T,
left_context_len=left_context_len,
decode_chunk_len=decode_chunk_len,
cnn_module_kernels=cnn_module_kernels,
num_encoder_layers=num_encoder_layers,
)
class RKNNModel:
def __init__(
self, encoder: str, decoder: str, joiner: str, target_platform="rk3588"
):
self.meta = get_meta_data(encoder)
self.encoder = init_model(encoder, custom_string=self.meta.to_str())
self.decoder = init_model(decoder)
self.joiner = init_model(joiner)
def release(self):
self.encoder.release()
self.decoder.release()
self.joiner.release()
def get_init_states(
self,
) -> List[np.ndarray]:
cached_len = []
cached_avg = []
cached_key = []
cached_val = []
cached_val2 = []
cached_conv1 = []
cached_conv2 = []
num_encoder_layers = self.meta.num_encoder_layers
encoder_dims = self.meta.encoder_dims
left_context_len = self.meta.left_context_len
attention_dims = self.meta.attention_dims
cnn_module_kernels = self.meta.cnn_module_kernels
num_encoders = len(num_encoder_layers)
N = 1
for i in range(num_encoders):
cached_len.append(np.zeros((num_encoder_layers[i], N), dtype=np.int64))
cached_avg.append(
np.zeros((num_encoder_layers[i], N, encoder_dims[i]), dtype=np.float32)
)
cached_key.append(
np.zeros(
(num_encoder_layers[i], left_context_len[i], N, attention_dims[i]),
dtype=np.float32,
)
)
cached_val.append(
np.zeros(
(
num_encoder_layers[i],
left_context_len[i],
N,
attention_dims[i] // 2,
),
dtype=np.float32,
)
)
cached_val2.append(
np.zeros(
(
num_encoder_layers[i],
left_context_len[i],
N,
attention_dims[i] // 2,
),
dtype=np.float32,
)
)
cached_conv1.append(
np.zeros(
(
num_encoder_layers[i],
N,
encoder_dims[i],
cnn_module_kernels[i] - 1,
),
dtype=np.float32,
)
)
cached_conv2.append(
np.zeros(
(
num_encoder_layers[i],
N,
encoder_dims[i],
cnn_module_kernels[i] - 1,
),
dtype=np.float32,
)
)
ans = (
cached_len
+ cached_avg
+ cached_key
+ cached_val
+ cached_val2
+ cached_conv1
+ cached_conv2
)
# for i, s in enumerate(ans):
# if s.ndim == 4:
# ans[i] = np.transpose(s, (0, 2, 3, 1))
return ans
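The list holds 7 cache groups (len, avg, key, val, val2, conv1, conv2), each with one tensor per encoder stack, so the 5-stack models exported above yield 35 state tensors. A quick sanity check, assuming an RKNNModel instance named model as in main() below:

states = model.get_init_states()
# 7 groups x one entry per encoder stack, e.g. 7 * 5 == 35
assert len(states) == 7 * len(model.meta.num_encoder_layers)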
def run_encoder(self, x: np.ndarray, states: List[np.ndarray]):
"""
Args:
x: (T, C), np.float32
states: A list of states
"""
x = np.expand_dims(x, axis=0)
out = self.encoder.inference(inputs=[x] + states, data_format="nchw")
# out[0], encoder_out, shape (1, 24, 512)
return out[0], out[1:]
def run_decoder(self, x: np.ndarray):
"""
Args:
x: (1, context_size), np.int64
Returns:
Return decoder_out, (1, C), np.float32
"""
return self.decoder.inference(inputs=[x])[0]
def run_joiner(self, encoder_out: np.ndarray, decoder_out: np.ndarray):
"""
Args:
encoder_out: (1, encoder_out_dim), np.float32
decoder_out: (1, decoder_out_dim), np.float32
Returns:
joiner_out: (1, vocab_size), np.float32
"""
return self.joiner.inference(inputs=[encoder_out, decoder_out])[0]
def main():
args = get_parser().parse_args()
print(vars(args))
id2token = load_tokens(args.tokens)
features = compute_features(args.wav)
model = RKNNModel(
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
)
print(model.meta)
states = model.get_init_states()
segment = model.meta.T
offset = model.meta.decode_chunk_len
context_size = 2
hyp = [0] * context_size
decoder_input = np.array([hyp], dtype=np.int64)
decoder_out = model.run_decoder(decoder_input)
i = 0
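# Streaming greedy search: each iteration feeds `segment` (= meta.T) feature
# frames to the encoder, then advances by `offset` (= meta.decode_chunk_len)
# frames, so consecutive chunks overlap by T - decode_chunk_len frames.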
while True:
if i + segment > features.shape[0]:
break
x = features[i : i + segment]
i += offset
encoder_out, states = model.run_encoder(x, states)
encoder_out = encoder_out.squeeze(0) # (1, T, C) -> (T, C)
num_frames = encoder_out.shape[0]
for k in range(num_frames):
joiner_out = model.run_joiner(encoder_out[k : k + 1], decoder_out)
joiner_out = joiner_out.squeeze(0)
max_token_id = joiner_out.argmax()
# assume 0 is the blank id
if max_token_id != 0:
hyp.append(max_token_id)
decoder_input = np.array([hyp[-context_size:]], dtype=np.int64)
decoder_out = model.run_decoder(decoder_input)
print(hyp)
final_hyp = hyp[context_size:]
print(final_hyp)
text = "".join([id2token[i] for i in final_hyp])
text = text.replace("▁", " ")  # "▁" marks word boundaries in BPE tokens
print(text)
if __name__ == "__main__":
main()