diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index a516a53c5..638e19498 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -10,7 +10,17 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--min-torch-version",
-        help="Minimu torch version",
+        help="Minimum torch version",
+    )
+
+    parser.add_argument(
+        "--torch-version",
+        help="If given, use only this torch version",
+    )
+
+    parser.add_argument(
+        "--python-version",
+        help="If given, use only this Python version",
     )
 
     return parser.parse_args()
@@ -52,7 +62,7 @@ def get_torchaudio_version(torch_version):
         return torch_version
 
 
-def get_matrix(min_torch_version):
+def get_matrix(min_torch_version, specified_torch_version, specified_python_version):
     k2_version = "1.24.4.dev20241029"
     kaldifeat_version = "1.25.5.dev20241029"
     version = "20241218"
@@ -71,6 +81,12 @@ def get_matrix(min_torch_version):
         torch_version += ["2.5.0"]
         torch_version += ["2.5.1"]
 
+    if specified_torch_version:
+        torch_version = [specified_torch_version]
+
+    if specified_python_version:
+        python_version = [specified_python_version]
+
     matrix = []
     for p in python_version:
         for t in torch_version:
@@ -115,7 +131,11 @@ def get_matrix(min_torch_version):
 
 def main():
     args = get_args()
-    matrix = get_matrix(min_torch_version=args.min_torch_version)
+    matrix = get_matrix(
+        min_torch_version=args.min_torch_version,
+        specified_torch_version=args.torch_version,
+        specified_python_version=args.python_version,
+    )
 
     print(json.dumps({"include": matrix}))
 
diff --git a/.github/scripts/librispeech/ASR/run_rknn.sh b/.github/scripts/librispeech/ASR/run_rknn.sh
new file mode 100755
index 000000000..32a150ef1
--- /dev/null
+++ b/.github/scripts/librispeech/ASR/run_rknn.sh
@@ -0,0 +1,69 @@
+#!/usr/bin/env bash
+
+set -ex
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+
+# https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed
+function export_bilingual_zh_en() {
+  d=exp_zh_en
+
+  mkdir $d
+  pushd $d
+
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/exp/pretrained.pt
+  mv pretrained.pt epoch-99.pt
+
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/data/lang_char_bpe/tokens.txt
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/data/lang_char_bpe/bpe.model
+
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/0.wav
+  curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/BAC009S0764W0164.wav
+  ls -lh
+  popd
+
+  ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
+    --dynamic-batch 0 \
+    --enable-int8-quantization 0 \
+    --tokens $d/tokens.txt \
+    --use-averaged-model 0 \
+    --epoch 99 \
+    --avg 1 \
+    --exp-dir $d/ \
+    --decode-chunk-len 32 \
+    --num-encoder-layers "2,4,3,2,4" \
+    --feedforward-dims "1024,1024,1536,1536,1024" \
+    --nhead "8,8,8,8,8" \
+    --encoder-dims "384,384,384,384,384" \
+    --attention-dims "192,192,192,192,192" \
+    --encoder-unmasked-dims "256,256,256,256,256" \
+    --zipformer-downsampling-factors "1,2,4,8,2" \
+    --cnn-module-kernels "31,31,31,31,31" \
+    --decoder-dim 512 \
+    --joiner-dim 512
+
+  ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
+    --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
+    --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
+    --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
+    --tokens $d/tokens.txt \
+    $d/0.wav
+
+  ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
+    --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
+    --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
+    --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
+    --tokens $d/tokens.txt \
+    $d/BAC009S0764W0164.wav
+
+  ls -lh $d/
+}
+
+export_bilingual_zh_en
diff --git a/.github/workflows/rknn.yml b/.github/workflows/rknn.yml
index 6c974a318..e1877e833 100644
--- a/.github/workflows/rknn.yml
+++ b/.github/workflows/rknn.yml
@@ -17,24 +17,75 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
+  generate_build_matrix:
+    if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+    # see https://github.com/pytorch/pytorch/pull/50633
+    runs-on: ubuntu-latest
+    outputs:
+      matrix: ${{ steps.set-matrix.outputs.matrix }}
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Generating build matrix
+        id: set-matrix
+        run: |
+          # outputting for debugging purposes
+          python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10
+          MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10)
+          echo "matrix=${MATRIX}" >> "$GITHUB_OUTPUT"
   rknn:
+    needs: generate_build_matrix
+    name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest]
-        python-version: ["3.10"]
+        ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
+
     steps:
       - uses: actions/checkout@v4
         with:
           fetch-depth: 0
 
       - name: Setup Python
+        if: false
        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
 
+      - name: Export ONNX model
+        uses: addnab/docker-run-action@v3
+        with:
+          image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+          options: |
+            --volume ${{ github.workspace }}/:/icefall
+          shell: bash
+          run: |
+            cat /etc/*release
+            lsb_release -a
+            uname -a
+            python3 --version
+            export PYTHONPATH=/icefall:$PYTHONPATH
+            cd /icefall
+            git config --global --add safe.directory /icefall
+
+            python3 -m torch.utils.collect_env
+            python3 -m k2.version
+            pip list
+
+            .github/scripts/librispeech/ASR/run_rknn.sh
+
+            # Install rknn
+            curl -SL -O https://huggingface.co/csukuangfj/rknn-toolkit2/resolve/main/rknn_toolkit2-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+            pip install ./*.whl "numpy<=1.26.4"
+            pip list | grep rknn
+            echo "---"
+            pip list
+            echo "---"
+
       - name: Install rknn
+        if: false
         shell: bash
         run: |
           ls
@@ -46,6 +97,7 @@ jobs:
           echo "---"
 
       - name: Run
+        if: false
         shell: bash
         run: |
           ls
@@ -74,11 +126,13 @@ jobs:
           ls -lh ../model
 
       - uses: actions/upload-artifact@v4
+        if: false
         with:
           name: rknn-files
           path: ./rknn_model_zoo/examples/zipformer/model/*.rknn
 
       - uses: actions/upload-artifact@v4
+        if: false
         with:
           name: onnx-files
           path: ./rknn_model_zoo/examples/zipformer/model/*.onnx
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
index 2de56837e..a4fbd93ba 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
@@ -85,6 +85,20 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
+    parser.add_argument(
+        "--dynamic-batch",
+        type=int,
+        default=1,
+        help="1 to support dynamic batch size. 0 to support only batch size == 1",
+    )
+
+    parser.add_argument(
+        "--enable-int8-quantization",
+        type=int,
+        default=1,
+        help="1 to also export int8 onnx models.",
+    )
+
     parser.add_argument(
         "--epoch",
         type=int,
@@ -257,6 +271,7 @@ def export_encoder_model_onnx(
     encoder_model: OnnxEncoder,
     encoder_filename: str,
     opset_version: int = 11,
+    dynamic_batch: bool = True,
 ) -> None:
     """
     Onnx model inputs:
@@ -274,6 +289,8 @@
         The filename to save the exported ONNX model.
       opset_version:
         The opset version to use.
+      dynamic_batch:
+        True to export a model supporting dynamic batch size
     """
 
     encoder_model.encoder.__class__.forward = (
@@ -379,7 +396,9 @@
             "encoder_out": {0: "N"},
             **inputs,
             **outputs,
-        },
+        }
+        if dynamic_batch
+        else {},
     )
 
     add_meta_data(filename=encoder_filename, meta_data=meta_data)
@@ -389,6 +408,7 @@ def export_decoder_model_onnx(
     decoder_model: nn.Module,
     decoder_filename: str,
     opset_version: int = 11,
+    dynamic_batch: bool = True,
 ) -> None:
     """Export the decoder model to ONNX format.
 
@@ -412,7 +432,7 @@
     """
     context_size = decoder_model.decoder.context_size
     vocab_size = decoder_model.decoder.vocab_size
-    y = torch.zeros(10, context_size, dtype=torch.int64)
+    y = torch.zeros(1, context_size, dtype=torch.int64)
     decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
@@ -425,7 +445,9 @@
         dynamic_axes={
             "y": {0: "N"},
             "decoder_out": {0: "N"},
-        },
+        }
+        if dynamic_batch
+        else {},
     )
     meta_data = {
         "context_size": str(context_size),
@@ -438,6 +460,7 @@ def export_joiner_model_onnx(
     joiner_model: nn.Module,
     joiner_filename: str,
     opset_version: int = 11,
+    dynamic_batch: bool = True,
 ) -> None:
     """Export the joiner model to ONNX format.
     The exported joiner model has two inputs:
@@ -452,8 +475,8 @@
     joiner_dim = joiner_model.output_linear.weight.shape[1]
     logging.info(f"joiner dim: {joiner_dim}")
 
-    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
-    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+    projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
+    projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
 
     torch.onnx.export(
         joiner_model,
@@ -470,7 +493,9 @@
             "encoder_out": {0: "N"},
             "decoder_out": {0: "N"},
             "logit": {0: "N"},
-        },
+        }
+        if dynamic_batch
+        else {},
     )
     meta_data = {
         "joiner_dim": str(joiner_dim),
@@ -629,6 +654,7 @@
         encoder,
         encoder_filename,
         opset_version=opset_version,
+        dynamic_batch=params.dynamic_batch == 1,
     )
     logging.info(f"Exported encoder to {encoder_filename}")
 
@@ -638,6 +664,7 @@
         decoder,
         decoder_filename,
         opset_version=opset_version,
+        dynamic_batch=params.dynamic_batch == 1,
     )
     logging.info(f"Exported decoder to {decoder_filename}")
 
@@ -647,37 +674,39 @@
         joiner,
         joiner_filename,
         opset_version=opset_version,
+        dynamic_batch=params.dynamic_batch == 1,
     )
     logging.info(f"Exported joiner to {joiner_filename}")
 
     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-    logging.info("Generate int8 quantization models")
+    if params.enable_int8_quantization:
+        logging.info("Generate int8 quantization models")
 
-    encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=encoder_filename,
-        model_output=encoder_filename_int8,
-        op_types_to_quantize=["MatMul"],
-        weight_type=QuantType.QInt8,
-    )
+        encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
+        quantize_dynamic(
+            model_input=encoder_filename,
+            model_output=encoder_filename_int8,
+            op_types_to_quantize=["MatMul"],
+            weight_type=QuantType.QInt8,
+        )
 
-    decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=decoder_filename,
-        model_output=decoder_filename_int8,
-        op_types_to_quantize=["MatMul", "Gather"],
-        weight_type=QuantType.QInt8,
-    )
+        decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
+        quantize_dynamic(
+            model_input=decoder_filename,
+            model_output=decoder_filename_int8,
+            op_types_to_quantize=["MatMul", "Gather"],
+            weight_type=QuantType.QInt8,
+        )
 
-    joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=joiner_filename,
-        model_output=joiner_filename_int8,
-        op_types_to_quantize=["MatMul"],
-        weight_type=QuantType.QInt8,
-    )
+        joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
+        quantize_dynamic(
+            model_input=joiner_filename,
+            model_output=joiner_filename_int8,
+            op_types_to_quantize=["MatMul"],
+            weight_type=QuantType.QInt8,
+        )
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
index 298d1889b..e5e513671 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
@@ -132,10 +132,18 @@ class OnnxModel:
             sess_options=self.session_opts,
             providers=["CPUExecutionProvider"],
         )
+        print("==========Encoder input==========")
+        for i in self.encoder.get_inputs():
+            print(i)
+        print("==========Encoder output==========")
+        for i in self.encoder.get_outputs():
+            print(i)
+
         self.init_encoder_states()
 
     def init_encoder_states(self, batch_size: int = 1):
         encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
+        print(encoder_meta)
         model_type = encoder_meta["model_type"]
         assert model_type == "zipformer", model_type
 
@@ -232,6 +240,12 @@ class OnnxModel:
             sess_options=self.session_opts,
             providers=["CPUExecutionProvider"],
         )
+        print("==========Decoder input==========")
+        for i in self.decoder.get_inputs():
+            print(i)
+        print("==========Decoder output==========")
+        for i in self.decoder.get_outputs():
+            print(i)
 
         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
         self.context_size = int(decoder_meta["context_size"])
@@ -247,6 +261,13 @@
             providers=["CPUExecutionProvider"],
         )
 
+        print("==========Joiner input==========")
+        for i in self.joiner.get_inputs():
+            print(i)
+        print("==========Joiner output==========")
+        for i in self.joiner.get_outputs():
+            print(i)
+
         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
         self.joiner_dim = int(joiner_meta["joiner_dim"])