Export Zipformer2 to RKNN

This commit is contained in:
Fangjun Kuang 2025-03-02 20:53:30 +08:00
parent db9fb8ad31
commit 7d6075b8e0
14 changed files with 776 additions and 52 deletions

View File

@ -12,7 +12,6 @@ log() {
cd egs/librispeech/ASR cd egs/librispeech/ASR
# https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed # https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed
# sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 # sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
function export_bilingual_zh_en() { function export_bilingual_zh_en() {
@ -124,7 +123,6 @@ function export_bilingual_zh_en_small() {
popd popd
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
--dynamic-batch 0 \ --dynamic-batch 0 \
--enable-int8-quantization 0 \ --enable-int8-quantization 0 \

74
.github/scripts/wenetspeech/ASR/run_rknn.sh vendored Executable file
View File

@ -0,0 +1,74 @@
#!/usr/bin/env bash
# Export the streaming zipformer WenetSpeech "small" model to ONNX and
# convert the resulting ONNX models to RKNN for several Rockchip platforms.
set -ex

python3 -m pip install kaldi-native-fbank soundfile librosa

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/wenetspeech/ASR

#https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#k2-fsa-icefall-asr-zipformer-wenetspeech-streaming-small-chinese
function export_zh_small() {
  d=exp_zh_small
  # -p so the script can be re-run without failing on an existing directory
  mkdir -p $d
  pushd $d

  curl -SL -O https://huggingface.co/k2-fsa/icefall-asr-zipformer-wenetspeech-streaming-small/resolve/main/data/lang_char/tokens.txt
  curl -SL -O https://huggingface.co/k2-fsa/icefall-asr-zipformer-wenetspeech-streaming-small/resolve/main/exp/pretrained.pt
  mv pretrained.pt epoch-99.pt

  curl -SL -o 0.wav https://huggingface.co/k2-fsa/icefall-asr-zipformer-wenetspeech-streaming-small/resolve/main/test_wavs/DEV_T0000000000.wav
  curl -SL -o 1.wav https://huggingface.co/k2-fsa/icefall-asr-zipformer-wenetspeech-streaming-small/resolve/main/test_wavs/DEV_T0000000001.wav
  curl -SL -o 2.wav https://huggingface.co/k2-fsa/icefall-asr-zipformer-wenetspeech-streaming-small/resolve/main/test_wavs/DEV_T0000000002.wav

  ls -lh
  popd

  # Export with a fixed batch size and without int8 models (see the two
  # flags below); the RKNN conversion step consumes the fp32 onnx files.
  ./zipformer/export-onnx-streaming.py \
    --dynamic-batch 0 \
    --enable-int8-quantization 0 \
    --tokens $d/tokens.txt \
    --use-averaged-model 0 \
    --epoch 99 \
    --avg 1 \
    --exp-dir $d \
    --use-ctc 0 \
    --use-transducer 1 \
    \
    --num-encoder-layers 2,2,2,2,2,2 \
    --feedforward-dim 512,768,768,768,768,768 \
    --encoder-dim 192,256,256,256,256,256 \
    --encoder-unmasked-dim 192,192,192,192,192,192 \
    \
    --chunk-size 32 \
    --left-context-frames 128 \
    --causal 1

  out=/icefall/rknn-models-small-wenetspeech
  mkdir -p $out

  for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
    mkdir -p $out/$platform

    ./zipformer/export_rknn_transducer_streaming.py \
      --in-encoder $d/encoder-epoch-99-avg-1-chunk-32-left-128.onnx \
      --in-decoder $d/decoder-epoch-99-avg-1-chunk-32-left-128.onnx \
      --in-joiner $d/joiner-epoch-99-avg-1-chunk-32-left-128.onnx \
      --out-encoder $out/$platform/encoder.rknn \
      --out-decoder $out/$platform/decoder.rknn \
      --out-joiner $out/$platform/joiner.rknn \
      --target-platform $platform

    cp $d/tokens.txt $out/$platform
    cp $d/*.wav $out/$platform

    ls -lh $out/$platform/
  done

  # Use -lh (not bare -h): -h only has an effect together with -l.
  ls -lh $out
  echo "---"
  ls -lh $out/*
}

export_zh_small

View File

@ -4,7 +4,7 @@ on:
push: push:
branches: branches:
- master - master
- ci-rknn-2 - rknn-zipformer2
pull_request: pull_request:
branches: branches:
@ -31,8 +31,8 @@ jobs:
id: set-matrix id: set-matrix
run: | run: |
# outputting for debugging purposes # outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10 python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.0.0 --python-version=3.10
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10) MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.0.0 --python-version=3.10)
echo "::set-output name=matrix::${MATRIX}" echo "::set-output name=matrix::${MATRIX}"
rknn: rknn:
needs: generate_build_matrix needs: generate_build_matrix
@ -54,7 +54,7 @@ jobs:
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Export ONNX model - name: Export RKNN model
uses: addnab/docker-run-action@v3 uses: addnab/docker-run-action@v3
with: with:
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
@ -76,17 +76,19 @@ jobs:
# Install rknn # Install rknn
curl -SL -O https://huggingface.co/csukuangfj/rknn-toolkit2/resolve/main/rknn_toolkit2-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl curl -SL -O https://huggingface.co/csukuangfj/rknn-toolkit2/resolve/main/rknn_toolkit2-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ./*.whl "numpy<=1.26.4" pip install ./*.whl "numpy<=1.26.4"
pip list | grep rknn pip list | grep rknn
echo "---" echo "---"
pip list pip list
echo "---" echo "---"
.github/scripts/librispeech/ASR/run_rknn.sh .github/scripts/wenetspeech/ASR/run_rknn.sh >log-wenetspeech.txt
# .github/scripts/librispeech/ASR/run_rknn.sh >log-librispeech.txt
- name: Display rknn models - name: Display rknn models (librispeech)
shell: bash shell: bash
if: false
run: | run: |
ls -lh ls -lh
@ -94,7 +96,31 @@ jobs:
echo "----" echo "----"
ls -lh rknn-models-small/* ls -lh rknn-models-small/*
- name: Collect results (small) - name: Display rknn models (wenetspeech)
shell: bash
run: |
ls -lh rknn-models-small-wenetspeech/*
- name: Collect results (small wenetspeech)
shell: bash
run: |
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#k2-fsa-icefall-asr-zipformer-wenetspeech-streaming-small-chinese
for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
dst=sherpa-onnx-$platform-streaming-zipformer-small-zh-2025-03-02
mkdir $dst
mkdir $dst/test_wavs
src=rknn-models-small-wenetspeech/$platform
cp -v $src/*.rknn $dst/
cp -v $src/tokens.txt $dst/
cp -v $src/*.wav $dst/test_wavs/
ls -lh $dst
tar cjfv $dst.tar.bz2 $dst
rm -rf $dst
done
- name: Collect results (small librispeech)
if: false
shell: bash shell: bash
run: | run: |
for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
@ -111,7 +137,8 @@ jobs:
rm -rf $dst rm -rf $dst
done done
- name: Collect results - name: Collect results (librispeech)
if: false
shell: bash shell: bash
run: | run: |
for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do
@ -167,7 +194,6 @@ jobs:
git merge -m "merge remote" --ff origin main git merge -m "merge remote" --ff origin main
dst=streaming-asr dst=streaming-asr
mkdir -p $dst mkdir -p $dst
rm -fv $dst/*
cp ../*rk*.tar.bz2 $dst/ cp ../*rk*.tar.bz2 $dst/
ls -lh $dst ls -lh $dst

View File

@ -72,7 +72,7 @@ def compute_features(filename: str, dim: int = 80) -> np.ndarray:
filename: filename:
Path to an audio file. Path to an audio file.
Returns: Returns:
Return a 1-D float32 tensor of shape (1, 80, 3000) containing the features. Return a 2-D float32 tensor of shape (T, dim) containing the features.
""" """
wave, sample_rate = load_audio(filename) wave, sample_rate = load_audio(filename)
if sample_rate != 16000: if sample_rate != 16000:

View File

@ -74,6 +74,20 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
parser.add_argument(
"--dynamic-batch",
type=int,
default=1,
help="1 to support dynamic batch size. 0 to support only batch size == 1",
)
parser.add_argument(
"--enable-int8-quantization",
type=int,
default=1,
help="1 to also export int8 onnx models.",
)
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
@ -270,6 +284,7 @@ def export_streaming_ctc_model_onnx(
model: OnnxModel, model: OnnxModel,
encoder_filename: str, encoder_filename: str,
opset_version: int = 11, opset_version: int = 11,
dynamic_batch: bool = True,
) -> None: ) -> None:
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
@ -408,7 +423,9 @@ def export_streaming_ctc_model_onnx(
"log_probs": {0: "N"}, "log_probs": {0: "N"},
**inputs, **inputs,
**outputs, **outputs,
}, }
if dynamic_batch
else {},
) )
add_meta_data(filename=encoder_filename, meta_data=meta_data) add_meta_data(filename=encoder_filename, meta_data=meta_data)
@ -547,9 +564,11 @@ def main():
model, model,
model_filename, model_filename,
opset_version=opset_version, opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
) )
logging.info(f"Exported model to {model_filename}") logging.info(f"Exported model to {model_filename}")
if params.enable_int8_quantization:
# Generate int8 quantization models # Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

View File

@ -93,6 +93,20 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
parser.add_argument(
"--dynamic-batch",
type=int,
default=1,
help="1 to support dynamic batch size. 0 to support only batch size == 1",
)
parser.add_argument(
"--enable-int8-quantization",
type=int,
default=1,
help="1 to also export int8 onnx models.",
)
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
@ -342,6 +356,7 @@ def export_encoder_model_onnx(
encoder_filename: str, encoder_filename: str,
opset_version: int = 11, opset_version: int = 11,
feature_dim: int = 80, feature_dim: int = 80,
dynamic_batch: bool = True,
) -> None: ) -> None:
encoder_model.encoder.__class__.forward = ( encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward encoder_model.encoder.__class__.streaming_forward
@ -482,7 +497,9 @@ def export_encoder_model_onnx(
"encoder_out": {0: "N"}, "encoder_out": {0: "N"},
**inputs, **inputs,
**outputs, **outputs,
}, }
if dynamic_batch
else {},
) )
add_meta_data(filename=encoder_filename, meta_data=meta_data) add_meta_data(filename=encoder_filename, meta_data=meta_data)
@ -492,6 +509,7 @@ def export_decoder_model_onnx(
decoder_model: OnnxDecoder, decoder_model: OnnxDecoder,
decoder_filename: str, decoder_filename: str,
opset_version: int = 11, opset_version: int = 11,
dynamic_batch: bool = True,
) -> None: ) -> None:
"""Export the decoder model to ONNX format. """Export the decoder model to ONNX format.
@ -514,7 +532,7 @@ def export_decoder_model_onnx(
context_size = decoder_model.decoder.context_size context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64) y = torch.zeros(1, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model) decoder_model = torch.jit.script(decoder_model)
torch.onnx.export( torch.onnx.export(
decoder_model, decoder_model,
@ -527,7 +545,9 @@ def export_decoder_model_onnx(
dynamic_axes={ dynamic_axes={
"y": {0: "N"}, "y": {0: "N"},
"decoder_out": {0: "N"}, "decoder_out": {0: "N"},
}, }
if dynamic_batch
else {},
) )
meta_data = { meta_data = {
@ -541,6 +561,7 @@ def export_joiner_model_onnx(
joiner_model: nn.Module, joiner_model: nn.Module,
joiner_filename: str, joiner_filename: str,
opset_version: int = 11, opset_version: int = 11,
dynamic_batch: bool = True,
) -> None: ) -> None:
"""Export the joiner model to ONNX format. """Export the joiner model to ONNX format.
The exported joiner model has two inputs: The exported joiner model has two inputs:
@ -555,8 +576,8 @@ def export_joiner_model_onnx(
joiner_dim = joiner_model.output_linear.weight.shape[1] joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}") logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
torch.onnx.export( torch.onnx.export(
joiner_model, joiner_model,
@ -573,7 +594,9 @@ def export_joiner_model_onnx(
"encoder_out": {0: "N"}, "encoder_out": {0: "N"},
"decoder_out": {0: "N"}, "decoder_out": {0: "N"},
"logit": {0: "N"}, "logit": {0: "N"},
}, }
if dynamic_batch
else {},
) )
meta_data = { meta_data = {
"joiner_dim": str(joiner_dim), "joiner_dim": str(joiner_dim),
@ -734,6 +757,7 @@ def main():
encoder_filename, encoder_filename,
opset_version=opset_version, opset_version=opset_version,
feature_dim=params.feature_dim, feature_dim=params.feature_dim,
dynamic_batch=params.dynamic_batch == 1,
) )
logging.info(f"Exported encoder to {encoder_filename}") logging.info(f"Exported encoder to {encoder_filename}")
@ -743,6 +767,7 @@ def main():
decoder, decoder,
decoder_filename, decoder_filename,
opset_version=opset_version, opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
) )
logging.info(f"Exported decoder to {decoder_filename}") logging.info(f"Exported decoder to {decoder_filename}")
@ -752,6 +777,7 @@ def main():
joiner, joiner,
joiner_filename, joiner_filename,
opset_version=opset_version, opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
) )
logging.info(f"Exported joiner to {joiner_filename}") logging.info(f"Exported joiner to {joiner_filename}")
@ -778,6 +804,7 @@ def main():
# Generate int8 quantization models # Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
if params.enable_int8_quantization:
logging.info("Generate int8 quantization models") logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"

View File

@ -0,0 +1,74 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
import argparse
import logging
from pathlib import Path
from typing import List
from rknn.api import RKNN
from test_rknn_on_cpu_simulator_ctc_streaming import RKNNModel
# Keep library logging quiet; only warnings and errors are printed.
logging.basicConfig(level=logging.WARNING)
# Rockchip platforms accepted by --target-platform.  The commented-out
# entries are excluded here — presumably unsupported by this flow; confirm
# before enabling them.
g_platforms = [
    # "rv1103",
    # "rv1103b",
    # "rv1106",
    # "rk2118",
    "rk3562",
    "rk3566",
    "rk3568",
    "rk3576",
    "rk3588",
]
def get_parser():
    """Build the command-line parser for the ONNX-to-RKNN CTC converter."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # All three options are required string arguments; register them in a
    # single loop to keep flag/help pairs side by side.
    for flag, help_text in (
        ("--target-platform", f"Supported values are: {','.join(g_platforms)}"),
        ("--in-model", "Path to the onnx model"),
        ("--out-model", "Path to the rknn model"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)

    return parser
def main():
    """Convert one streaming zipformer CTC ONNX model into an RKNN model."""
    parser = get_parser()
    args = parser.parse_args()
    print(vars(args))

    rknn_model = RKNNModel(
        model=args.in_model,
        target_platform=args.target_platform,
    )
    print(rknn_model.meta)

    rknn_model.export_rknn(model=args.out_model)
    rknn_model.release()


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,139 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
import argparse
import logging
from pathlib import Path
from typing import List
from rknn.api import RKNN
from test_rknn_on_cpu_simulator_ctc_streaming import (
MetaData,
get_meta_data,
init_model,
export_rknn,
)
# Keep library logging quiet; only warnings and errors are printed.
logging.basicConfig(level=logging.WARNING)
# Rockchip platforms accepted by --target-platform.  The commented-out
# entries are excluded here — presumably unsupported by this flow; confirm
# before enabling them.
g_platforms = [
    # "rv1103",
    # "rv1103b",
    # "rv1106",
    # "rk2118",
    "rk3562",
    "rk3566",
    "rk3568",
    "rk3576",
    "rk3588",
]
def get_parser():
    """Build the command-line parser for the transducer ONNX-to-RKNN converter."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Every option is a required string; declare flag/help pairs in one
    # table and register them uniformly.
    arg_specs = (
        ("--target-platform", f"Supported values are: {','.join(g_platforms)}"),
        ("--in-encoder", "Path to the encoder onnx model"),
        ("--in-decoder", "Path to the decoder onnx model"),
        ("--in-joiner", "Path to the joiner onnx model"),
        ("--out-encoder", "Path to the encoder rknn model"),
        ("--out-decoder", "Path to the decoder rknn model"),
        ("--out-joiner", "Path to the joiner rknn model"),
    )
    for flag, help_text in arg_specs:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    return parser
class RKNNModel:
    """Bundle of the RKNN encoder/decoder/joiner of a streaming transducer."""

    def __init__(
        self,
        encoder: str,
        decoder: str,
        joiner: str,
        target_platform: str,
    ):
        # Model metadata is read from the encoder and embedded into the
        # converted encoder as a custom string so runtimes can recover it.
        self.meta = get_meta_data(encoder)
        self.encoder = init_model(
            encoder,
            custom_string=self.meta.to_str(),
            target_platform=target_platform,
        )
        self.decoder = init_model(decoder, target_platform=target_platform)
        self.joiner = init_model(joiner, target_platform=target_platform)

    def export_rknn(self, encoder, decoder, joiner):
        """Write the three converted models to the given output paths."""
        for rknn_obj, out_path in (
            (self.encoder, encoder),
            (self.decoder, decoder),
            (self.joiner, joiner),
        ):
            export_rknn(rknn_obj, out_path)

    def release(self):
        """Release the resources held by the three RKNN objects."""
        for rknn_obj in (self.encoder, self.decoder, self.joiner):
            rknn_obj.release()
def main():
    """Convert streaming transducer ONNX models into RKNN models."""
    args = get_parser().parse_args()
    print(vars(args))

    converter = RKNNModel(
        encoder=args.in_encoder,
        decoder=args.in_decoder,
        joiner=args.in_joiner,
        target_platform=args.target_platform,
    )
    print(converter.meta)

    converter.export_rknn(
        encoder=args.out_encoder,
        decoder=args.out_decoder,
        joiner=args.out_joiner,
    )
    converter.release()


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,362 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
import argparse
from pathlib import Path
from typing import List, Tuple
import kaldi_native_fbank as knf
import numpy as np
import soundfile as sf
from rknn.api import RKNN
def get_parser():
    """Build the command-line parser for the CPU-simulator test script."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Three required string arguments, registered from a small table.
    for flag, help_text in (
        ("--model", "Path to the onnx model"),
        ("--tokens", "Path to the tokens.txt"),
        ("--wav", "Path to test wave"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)

    return parser
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read an audio file and return (mono float32 samples, sample rate)."""
    audio, rate = sf.read(filename, always_2d=True, dtype="float32")
    # Keep only the first channel and make it contiguous in memory.
    mono = np.ascontiguousarray(audio[:, 0])
    return mono, rate
def compute_features(filename: str, dim: int = 80) -> np.ndarray:
    """Compute online fbank features for an audio file.

    Args:
      filename:
        Path to an audio file.
      dim:
        Number of mel bins.
    Returns:
      Return a 2-D float32 tensor of shape (T, dim) containing the features.
    """
    samples, rate = load_audio(filename)
    if rate != 16000:
        import librosa

        samples = librosa.resample(samples, orig_sr=rate, target_sr=16000)
        rate = 16000

    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.mel_opts.num_bins = dim

    fbank = knf.OnlineFbank(opts)
    fbank.accept_waveform(16000, samples)
    # Feed 0.5 s of trailing silence so the last frames are flushed out.
    fbank.accept_waveform(16000, np.zeros(int(0.5 * 16000), dtype=np.float32))
    fbank.input_finished()

    frames = [fbank.get_frame(i) for i in range(fbank.num_frames_ready)]
    return np.stack(frames, axis=0)
def load_tokens(filename):
    """Load a tokens.txt file mapping integer IDs to token strings.

    Each non-empty line has the form ``<token> <id>``.  The ID is taken from
    the last whitespace-separated field so that tokens containing spaces are
    handled, and blank lines (e.g., a trailing newline) are skipped instead
    of raising a ValueError as the previous ``line.split()`` unpacking did.

    Args:
      filename:
        Path to tokens.txt.
    Returns:
      A dict mapping the integer token ID to the token string.
    """
    tokens = dict()
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line.strip():
                continue  # skip blank lines
            t, i = line.rsplit(maxsplit=1)
            tokens[int(i)] = t
    return tokens
def init_model(filename, target_platform="rk3588", custom_string=None):
    """Load an ONNX model and build an initialized RKNN object for it.

    Args:
      filename:
        Path to the onnx model.
      target_platform:
        The Rockchip platform to build for, e.g. rk3588.
      custom_string:
        Optional metadata string embedded into the converted model.
    Returns:
      An RKNN object with a built model and an initialized runtime.
    """
    rknn = RKNN(verbose=False)

    rknn.config(target_platform=target_platform, custom_string=custom_string)
    if not Path(filename).is_file():
        # Include the filename in every error message so CI logs show
        # which of the three models (encoder/decoder/joiner) failed.
        exit(f"{filename} does not exist")

    ret = rknn.load_onnx(model=filename)
    if ret != 0:
        exit(f"Load model {filename} failed!")

    ret = rknn.build(do_quantization=False)
    if ret != 0:
        exit(f"Build model {filename} failed!")

    ret = rknn.init_runtime()
    if ret != 0:
        exit(f"Failed to init rknn runtime for {filename}")
    return rknn
class MetaData:
    """Metadata describing a streaming zipformer model."""

    def __init__(
        self,
        model_type: str,
        decode_chunk_len: int,
        T: int,
        num_encoder_layers: List[int],
        encoder_dims: List[int],
        cnn_module_kernels: List[int],
        left_context_len: List[int],
        query_head_dims: List[int],
        value_head_dims: List[int],
        num_heads: List[int],
    ):
        self.model_type = model_type
        self.decode_chunk_len = decode_chunk_len
        self.T = T
        self.num_encoder_layers = num_encoder_layers
        self.encoder_dims = encoder_dims
        self.cnn_module_kernels = cnn_module_kernels
        self.left_context_len = left_context_len
        self.query_head_dims = query_head_dims
        self.value_head_dims = value_head_dims
        self.num_heads = num_heads

    def __str__(self) -> str:
        return self.to_str()

    def to_str(self) -> str:
        """Serialize the metadata as a ';'-separated key=value string."""

        def join_ints(values):
            # Per-stack integer lists are rendered comma-separated.
            return ",".join(str(v) for v in values)

        parts = [
            f"model_type={self.model_type}",
            f"decode_chunk_len={self.decode_chunk_len}",
            f"T={self.T}",
            f"num_encoder_layers={join_ints(self.num_encoder_layers)}",
            f"encoder_dims={join_ints(self.encoder_dims)}",
            f"cnn_module_kernels={join_ints(self.cnn_module_kernels)}",
            f"left_context_len={join_ints(self.left_context_len)}",
            f"query_head_dims={join_ints(self.query_head_dims)}",
            f"value_head_dims={join_ints(self.value_head_dims)}",
            f"num_heads={join_ints(self.num_heads)}",
        ]
        s = ";".join(parts)
        # The rknn custom_string this is stored in has limited capacity.
        assert len(s) < 1024, (s, len(s))
        return s
def get_meta_data(model: str):
    """Read the custom metadata map of an ONNX model into a MetaData object.

    Args:
      model:
        Path to the onnx model.
    Returns:
      A MetaData instance populated from the model's custom metadata map.
    """
    import onnxruntime

    session_opts = onnxruntime.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    sess = onnxruntime.InferenceSession(
        model,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    for node in sess.get_inputs():
        print(node)
    print("-----")
    for node in sess.get_outputs():
        print(node)

    meta = sess.get_modelmeta().custom_metadata_map
    print(meta)
    # Example of the metadata map:
    # {'num_heads': '4,4,4,8,4,4', 'query_head_dims': '32,32,32,32,32,32',
    #  'cnn_module_kernels': '31,31,15,15,15,31',
    #  'num_encoder_layers': '2,2,3,4,3,2', ' version': '1',
    #  'comment': 'streaming ctc zipformer2',
    #  'model_type': 'zipformer2',
    #  'encoder_dims': '192,256,384,512,384,256',
    #  'model_author': 'k2-fsa', 'T': '77',
    #  'value_head_dims': '12,12,12,12,12,12',
    #  'left_context_len': '128,64,32,16,32,64',
    #  'decode_chunk_len': '64'}

    def ints(key):
        # Comma-separated per-stack integers, e.g. "2,2,3,4,3,2".
        return list(map(int, meta[key].split(",")))

    return MetaData(
        model_type=meta["model_type"],
        decode_chunk_len=int(meta["decode_chunk_len"]),
        T=int(meta["T"]),
        num_encoder_layers=ints("num_encoder_layers"),
        encoder_dims=ints("encoder_dims"),
        cnn_module_kernels=ints("cnn_module_kernels"),
        left_context_len=ints("left_context_len"),
        query_head_dims=ints("query_head_dims"),
        value_head_dims=ints("value_head_dims"),
        num_heads=ints("num_heads"),
    )
def export_rknn(rknn, filename):
    """Export a built RKNN object to ``filename``.

    Exits the process on failure.  The output path is included in the
    error message (previously it was missing, making failures hard to
    attribute to a specific model).
    """
    ret = rknn.export_rknn(filename)
    if ret != 0:
        exit(f"Export rknn model to {filename} failed!")
class RKNNModel:
    """A streaming zipformer CTC ONNX model converted to RKNN."""

    def __init__(self, model: str, target_platform="rk3588"):
        """
        Args:
          model:
            Path to the onnx model.
          target_platform:
            The Rockchip platform to build for, e.g. rk3588.
        """
        self.meta = get_meta_data(model)
        # Fix: forward target_platform to init_model.  Previously the
        # argument was accepted but never passed on, so the model was
        # always built for the default rk3588 platform.
        self.model = init_model(
            model,
            custom_string=self.meta.to_str(),
            target_platform=target_platform,
        )

    def export_rknn(self, model: str):
        """Export the converted model to the given rknn file path."""
        export_rknn(self.model, model)

    def release(self):
        """Release the resources held by the underlying RKNN object."""
        self.model.release()

    def get_init_states(
        self,
    ) -> List[np.ndarray]:
        """Return zero-initialized streaming states for the encoder.

        For each layer of each encoder stack the states are appended in the
        order cached_key, cached_nonlin_attn, cached_val1, cached_val2,
        cached_conv1, cached_conv2, followed by the embed states and
        processed_lens.  This order must match the model's input order.
        """
        states = []

        num_encoder_layers = self.meta.num_encoder_layers
        encoder_dims = self.meta.encoder_dims
        left_context_len = self.meta.left_context_len
        cnn_module_kernels = self.meta.cnn_module_kernels
        query_head_dims = self.meta.query_head_dims
        value_head_dims = self.meta.value_head_dims
        num_heads = self.meta.num_heads

        num_encoders = len(num_encoder_layers)
        N = 1  # batch size is fixed to 1 for RKNN

        for i in range(num_encoders):
            num_layers = num_encoder_layers[i]
            key_dim = query_head_dims[i] * num_heads[i]
            embed_dim = encoder_dims[i]
            nonlin_attn_head_dim = 3 * embed_dim // 4
            value_dim = value_head_dims[i] * num_heads[i]
            conv_left_pad = cnn_module_kernels[i] // 2

            for layer in range(num_layers):
                cached_key = np.zeros(
                    (left_context_len[i], N, key_dim), dtype=np.float32
                )
                cached_nonlin_attn = np.zeros(
                    (1, N, left_context_len[i], nonlin_attn_head_dim),
                    dtype=np.float32,
                )
                cached_val1 = np.zeros(
                    (left_context_len[i], N, value_dim),
                    dtype=np.float32,
                )
                cached_val2 = np.zeros(
                    (left_context_len[i], N, value_dim),
                    dtype=np.float32,
                )
                cached_conv1 = np.zeros(
                    (N, embed_dim, conv_left_pad), dtype=np.float32
                )
                cached_conv2 = np.zeros(
                    (N, embed_dim, conv_left_pad), dtype=np.float32
                )
                states += [
                    cached_key,
                    cached_nonlin_attn,
                    cached_val1,
                    cached_val2,
                    cached_conv1,
                    cached_conv2,
                ]

        embed_states = np.zeros((N, 128, 3, 19), dtype=np.float32)
        states.append(embed_states)

        processed_lens = np.zeros((N,), dtype=np.int64)
        states.append(processed_lens)

        return states

    def run_model(self, x: np.ndarray, states: List[np.ndarray]):
        """Run one streaming chunk through the model.

        Args:
          x: (T, C), np.float32
          states: A list of states
        Returns:
          A tuple (log_probs, next_states) where log_probs is out[0]
          of the model, shape (N, T, C).
        """
        x = np.expand_dims(x, axis=0)

        out = self.model.inference(inputs=[x] + states, data_format="nchw")
        # out[0]: log_probs, (N, T, C)

        return out[0], out[1:]
def main():
    """Greedy-search CTC decoding of one wave file with an RKNN model."""
    args = get_parser().parse_args()
    print(vars(args))

    id2token = load_tokens(args.tokens)
    features = compute_features(args.wav)

    model = RKNNModel(
        model=args.model,
    )
    print(model.meta)

    states = model.get_init_states()

    segment = model.meta.T  # frames consumed per chunk
    offset = model.meta.decode_chunk_len  # frame shift between chunks

    ans = []
    blank = 0
    prev = -1
    i = 0
    while True:
        if i + segment > features.shape[0]:
            break
        x = features[i : i + segment]
        i += offset

        log_probs, states = model.run_model(x, states)
        log_probs = log_probs[0]  # (N, T, C) -> (T, C)
        ids = log_probs.argmax(axis=1)

        # Greedy CTC: emit a token when it is neither blank nor a repeat
        # of the previous frame.  Fix: the loop previously tested and
        # appended the frame counter ``i`` instead of the predicted id
        # ``k`` (the loop variable was never used), and ``prev`` must be
        # updated on every frame so blanks reset repeat suppression.
        for k in ids:
            if k != blank and k != prev:
                ans.append(k)
            prev = k

    tokens = [id2token[i] for i in ans]
    underline = ""
    #  underline = b"\xe2\x96\x81".decode()
    # NOTE(review): replace("", " ") inserts a space between every
    # character — presumably intended for Chinese output; confirm.
    text = "".join(tokens).replace(underline, " ").strip()
    print(ans)
    print(args.wav)
    print(text)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export_rknn_ctc_streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export_rknn_transducer_streaming.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py