use docker

This commit is contained in:
Fangjun Kuang 2025-02-25 16:46:02 +08:00
parent 6bd0c45a8d
commit 61933af28f
5 changed files with 226 additions and 33 deletions

View File

@ -10,7 +10,17 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--min-torch-version",
help="Minimu torch version",
help="torch version",
)
parser.add_argument(
"--torch-version",
help="torch version",
)
parser.add_argument(
"--python-version",
help="python version",
)
return parser.parse_args()
@ -52,7 +62,7 @@ def get_torchaudio_version(torch_version):
return torch_version
def get_matrix(min_torch_version):
def get_matrix(min_torch_version, specified_torch_version, specified_python_version):
k2_version = "1.24.4.dev20241029"
kaldifeat_version = "1.25.5.dev20241029"
version = "20241218"
@ -71,6 +81,12 @@ def get_matrix(min_torch_version):
torch_version += ["2.5.0"]
torch_version += ["2.5.1"]
if specified_torch_version:
torch_version = [specified_torch_version]
if specified_python_version:
python_version = [specified_python_version]
matrix = []
for p in python_version:
for t in torch_version:
@ -115,7 +131,11 @@ def get_matrix(min_torch_version):
def main():
args = get_args()
matrix = get_matrix(min_torch_version=args.min_torch_version)
matrix = get_matrix(
min_torch_version=args.min_torch_version,
specified_torch_version=args.torch_version,
specified_python_version=args.python_version,
)
print(json.dumps({"include": matrix}))

69
.github/scripts/librispeech/ASR/run_rknn.sh vendored Executable file
View File

@ -0,0 +1,69 @@
#!/usr/bin/env bash
set -ex
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/librispeech/ASR
# https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed
function export_bilingual_zh_en() {
d=exp_zh_en
mkdir $d
pushd $d
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/exp/pretrained.pt
mv pretrained.pt epoch-99.pt
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/data/lang_char_bpe/tokens.txt
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/data/lang_char_bpe/bpe.model
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/0.wav
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/BAC009S0764W0164.wav
ls -lh
popd
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
--dynamic-batch 0 \
--enable-int8-quantization 0 \
--tokens $d/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $d/ \
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
--feedforward-dims "1024,1024,1536,1536,1024" \
--nhead "8,8,8,8,8" \
--encoder-dims "384,384,384,384,384" \
--attention-dims "192,192,192,192,192" \
--encoder-unmasked-dims "256,256,256,256,256" \
--zipformer-downsampling-factors "1,2,4,8,2" \
--cnn-module-kernels "31,31,31,31,31" \
--decoder-dim 512 \
--joiner-dim 512
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
--encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
--tokens $d/tokens.txt \
$d/0.wav
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
--encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
--tokens $d/tokens.txt \
$d/BAC009S0764W0164.wav
ls -lh $d/
}
export_bilingual_zh_en

View File

@ -17,24 +17,75 @@ concurrency:
cancel-in-progress: true
jobs:
generate_build_matrix:
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10)
echo "::set-output name=matrix::${MATRIX}"
rknn:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.10"]
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
if: false
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Export ONNX model
uses: addnab/docker-run-action@v3
with:
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
run: |
cat /etc/*release
lsb_release -a
uname -a
python3 --version
export PYTHONPATH=/icefall:$PYTHONPATH
cd /icefall
git config --global --add safe.directory /icefall
python3 -m torch.utils.collect_env
python3 -m k2.version
pip list
.github/scripts/librispeech/ASR/run_rknn.sh
# Install rknn
curl -SL -O https://huggingface.co/csukuangfj/rknn-toolkit2/resolve/main/rknn_toolkit2-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
pip install ./*.whl "numpy<=1.26.4"
pip list | grep rknn
echo "---"
pip list
echo "---"
- name: Install rknn
if: false
shell: bash
run: |
ls
@ -46,6 +97,7 @@ jobs:
echo "---"
- name: Run
if: false
shell: bash
run: |
ls
@ -74,11 +126,13 @@ jobs:
ls -lh ../model
- uses: actions/upload-artifact@v4
if: false
with:
name: rknn-files
path: ./rknn_model_zoo/examples/zipformer/model/*.rknn
- uses: actions/upload-artifact@v4
if: false
with:
name: onnx-files
path: ./rknn_model_zoo/examples/zipformer/model/*.onnx

View File

@ -85,6 +85,20 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--dynamic-batch",
type=int,
default=1,
help="1 to support dynamic batch size. 0 to support only batch size == 1",
)
parser.add_argument(
"--enable-int8-quantization",
type=int,
default=1,
help="1 to also export int8 onnx models.",
)
parser.add_argument(
"--epoch",
type=int,
@ -257,6 +271,7 @@ def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
) -> None:
"""
Onnx model inputs:
@ -274,6 +289,8 @@ def export_encoder_model_onnx(
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
dynamic_batch:
True to export a model supporting dynamic batch size
"""
encoder_model.encoder.__class__.forward = (
@ -379,7 +396,9 @@ def export_encoder_model_onnx(
"encoder_out": {0: "N"},
**inputs,
**outputs,
},
}
if dynamic_batch
else {},
)
add_meta_data(filename=encoder_filename, meta_data=meta_data)
@ -389,6 +408,7 @@ def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
) -> None:
"""Export the decoder model to ONNX format.
@ -412,7 +432,7 @@ def export_decoder_model_onnx(
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
y = torch.zeros(1, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
@ -425,7 +445,9 @@ def export_decoder_model_onnx(
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
}
if dynamic_batch
else {},
)
meta_data = {
"context_size": str(context_size),
@ -438,6 +460,7 @@ def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
@ -452,8 +475,8 @@ def export_joiner_model_onnx(
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
@ -470,7 +493,9 @@ def export_joiner_model_onnx(
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
}
if dynamic_batch
else {},
)
meta_data = {
"joiner_dim": str(joiner_dim),
@ -629,6 +654,7 @@ def main():
encoder,
encoder_filename,
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
)
logging.info(f"Exported encoder to {encoder_filename}")
@ -638,6 +664,7 @@ def main():
decoder,
decoder_filename,
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
)
logging.info(f"Exported decoder to {decoder_filename}")
@ -647,37 +674,39 @@ def main():
joiner,
joiner_filename,
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
if params.enable_int8_quantization:
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":

View File

@ -132,10 +132,18 @@ class OnnxModel:
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
print("==========Encoder input==========")
for i in self.encoder.get_inputs():
print(i)
print("==========Encoder output==========")
for i in self.encoder.get_outputs():
print(i)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
print(encoder_meta)
model_type = encoder_meta["model_type"]
assert model_type == "zipformer", model_type
@ -232,6 +240,12 @@ class OnnxModel:
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
print("==========Decoder input==========")
for i in self.decoder.get_inputs():
print(i)
print("==========Decoder output==========")
for i in self.decoder.get_outputs():
print(i)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
@ -247,6 +261,13 @@ class OnnxModel:
providers=["CPUExecutionProvider"],
)
print("==========Joiner input==========")
for i in self.joiner.get_inputs():
print(i)
print("==========Joiner output==========")
for i in self.joiner.get_outputs():
print(i)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])