mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
use docker
This commit is contained in:
parent
6bd0c45a8d
commit
61933af28f
26
.github/scripts/docker/generate_build_matrix.py
vendored
26
.github/scripts/docker/generate_build_matrix.py
vendored
@ -10,7 +10,17 @@ def get_args():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-torch-version",
|
"--min-torch-version",
|
||||||
help="Minimu torch version",
|
help="torch version",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--torch-version",
|
||||||
|
help="torch version",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--python-version",
|
||||||
|
help="python version",
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -52,7 +62,7 @@ def get_torchaudio_version(torch_version):
|
|||||||
return torch_version
|
return torch_version
|
||||||
|
|
||||||
|
|
||||||
def get_matrix(min_torch_version):
|
def get_matrix(min_torch_version, specified_torch_version, specified_python_version):
|
||||||
k2_version = "1.24.4.dev20241029"
|
k2_version = "1.24.4.dev20241029"
|
||||||
kaldifeat_version = "1.25.5.dev20241029"
|
kaldifeat_version = "1.25.5.dev20241029"
|
||||||
version = "20241218"
|
version = "20241218"
|
||||||
@ -71,6 +81,12 @@ def get_matrix(min_torch_version):
|
|||||||
torch_version += ["2.5.0"]
|
torch_version += ["2.5.0"]
|
||||||
torch_version += ["2.5.1"]
|
torch_version += ["2.5.1"]
|
||||||
|
|
||||||
|
if specified_torch_version:
|
||||||
|
torch_version = [specified_torch_version]
|
||||||
|
|
||||||
|
if specified_python_version:
|
||||||
|
python_version = [specified_python_version]
|
||||||
|
|
||||||
matrix = []
|
matrix = []
|
||||||
for p in python_version:
|
for p in python_version:
|
||||||
for t in torch_version:
|
for t in torch_version:
|
||||||
@ -115,7 +131,11 @@ def get_matrix(min_torch_version):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
matrix = get_matrix(min_torch_version=args.min_torch_version)
|
matrix = get_matrix(
|
||||||
|
min_torch_version=args.min_torch_version,
|
||||||
|
specified_torch_version=args.torch_version,
|
||||||
|
specified_python_version=args.python_version,
|
||||||
|
)
|
||||||
print(json.dumps({"include": matrix}))
|
print(json.dumps({"include": matrix}))
|
||||||
|
|
||||||
|
|
||||||
|
69
.github/scripts/librispeech/ASR/run_rknn.sh
vendored
Executable file
69
.github/scripts/librispeech/ASR/run_rknn.sh
vendored
Executable file
@ -0,0 +1,69 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
|
||||||
|
# https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed
|
||||||
|
function export_bilingual_zh_en() {
|
||||||
|
d=exp_zh_en
|
||||||
|
|
||||||
|
mkdir $d
|
||||||
|
pushd $d
|
||||||
|
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/exp/pretrained.pt
|
||||||
|
mv pretrained.pt epoch-99.pt
|
||||||
|
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/data/lang_char_bpe/tokens.txt
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/data/lang_char_bpe/bpe.model
|
||||||
|
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/0.wav
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/BAC009S0764W0164.wav
|
||||||
|
ls -lh
|
||||||
|
popd
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
|
||||||
|
--dynamic-batch 0 \
|
||||||
|
--enable-int8-quantization 0 \
|
||||||
|
--tokens $d/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $d/ \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--num-encoder-layers "2,4,3,2,4" \
|
||||||
|
--feedforward-dims "1024,1024,1536,1536,1024" \
|
||||||
|
--nhead "8,8,8,8,8" \
|
||||||
|
--encoder-dims "384,384,384,384,384" \
|
||||||
|
--attention-dims "192,192,192,192,192" \
|
||||||
|
--encoder-unmasked-dims "256,256,256,256,256" \
|
||||||
|
--zipformer-downsampling-factors "1,2,4,8,2" \
|
||||||
|
--cnn-module-kernels "31,31,31,31,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $d/tokens.txt \
|
||||||
|
$d/0.wav
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $d/tokens.txt \
|
||||||
|
$d/BAC009S0764W0164.wav
|
||||||
|
|
||||||
|
ls -lh $d/
|
||||||
|
}
|
||||||
|
|
||||||
|
export_bilingual_zh_en
|
58
.github/workflows/rknn.yml
vendored
58
.github/workflows/rknn.yml
vendored
@ -17,24 +17,75 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
generate_build_matrix:
|
||||||
|
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
|
||||||
|
# see https://github.com/pytorch/pytorch/pull/50633
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
outputs:
|
||||||
|
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
- name: Generating build matrix
|
||||||
|
id: set-matrix
|
||||||
|
run: |
|
||||||
|
# outputting for debugging purposes
|
||||||
|
python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10
|
||||||
|
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --torch-version=2.4.0 --python-version=3.10)
|
||||||
|
echo "::set-output name=matrix::${MATRIX}"
|
||||||
rknn:
|
rknn:
|
||||||
|
needs: generate_build_matrix
|
||||||
|
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
|
||||||
python-version: ["3.10"]
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
|
if: false
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Export ONNX model
|
||||||
|
uses: addnab/docker-run-action@v3
|
||||||
|
with:
|
||||||
|
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
|
||||||
|
options: |
|
||||||
|
--volume ${{ github.workspace }}/:/icefall
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cat /etc/*release
|
||||||
|
lsb_release -a
|
||||||
|
uname -a
|
||||||
|
python3 --version
|
||||||
|
export PYTHONPATH=/icefall:$PYTHONPATH
|
||||||
|
cd /icefall
|
||||||
|
git config --global --add safe.directory /icefall
|
||||||
|
|
||||||
|
python3 -m torch.utils.collect_env
|
||||||
|
python3 -m k2.version
|
||||||
|
pip list
|
||||||
|
|
||||||
|
.github/scripts/librispeech/ASR/run_rknn.sh
|
||||||
|
|
||||||
|
# Install rknn
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/rknn-toolkit2/resolve/main/rknn_toolkit2-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||||
|
pip install ./*.whl "numpy<=1.26.4"
|
||||||
|
pip list | grep rknn
|
||||||
|
echo "---"
|
||||||
|
pip list
|
||||||
|
echo "---"
|
||||||
|
|
||||||
- name: Install rknn
|
- name: Install rknn
|
||||||
|
if: false
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
ls
|
ls
|
||||||
@ -46,6 +97,7 @@ jobs:
|
|||||||
echo "---"
|
echo "---"
|
||||||
|
|
||||||
- name: Run
|
- name: Run
|
||||||
|
if: false
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
ls
|
ls
|
||||||
@ -74,11 +126,13 @@ jobs:
|
|||||||
ls -lh ../model
|
ls -lh ../model
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
|
if: false
|
||||||
with:
|
with:
|
||||||
name: rknn-files
|
name: rknn-files
|
||||||
path: ./rknn_model_zoo/examples/zipformer/model/*.rknn
|
path: ./rknn_model_zoo/examples/zipformer/model/*.rknn
|
||||||
|
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
|
if: false
|
||||||
with:
|
with:
|
||||||
name: onnx-files
|
name: onnx-files
|
||||||
path: ./rknn_model_zoo/examples/zipformer/model/*.onnx
|
path: ./rknn_model_zoo/examples/zipformer/model/*.onnx
|
||||||
|
@ -85,6 +85,20 @@ def get_parser():
|
|||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-batch",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="1 to support dynamic batch size. 0 to support only batch size == 1",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-int8-quantization",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="1 to also export int8 onnx models.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
@ -257,6 +271,7 @@ def export_encoder_model_onnx(
|
|||||||
encoder_model: OnnxEncoder,
|
encoder_model: OnnxEncoder,
|
||||||
encoder_filename: str,
|
encoder_filename: str,
|
||||||
opset_version: int = 11,
|
opset_version: int = 11,
|
||||||
|
dynamic_batch: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Onnx model inputs:
|
Onnx model inputs:
|
||||||
@ -274,6 +289,8 @@ def export_encoder_model_onnx(
|
|||||||
The filename to save the exported ONNX model.
|
The filename to save the exported ONNX model.
|
||||||
opset_version:
|
opset_version:
|
||||||
The opset version to use.
|
The opset version to use.
|
||||||
|
dynamic_batch:
|
||||||
|
True to export a model supporting dynamic batch size
|
||||||
"""
|
"""
|
||||||
|
|
||||||
encoder_model.encoder.__class__.forward = (
|
encoder_model.encoder.__class__.forward = (
|
||||||
@ -379,7 +396,9 @@ def export_encoder_model_onnx(
|
|||||||
"encoder_out": {0: "N"},
|
"encoder_out": {0: "N"},
|
||||||
**inputs,
|
**inputs,
|
||||||
**outputs,
|
**outputs,
|
||||||
},
|
}
|
||||||
|
if dynamic_batch
|
||||||
|
else {},
|
||||||
)
|
)
|
||||||
|
|
||||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
@ -389,6 +408,7 @@ def export_decoder_model_onnx(
|
|||||||
decoder_model: nn.Module,
|
decoder_model: nn.Module,
|
||||||
decoder_filename: str,
|
decoder_filename: str,
|
||||||
opset_version: int = 11,
|
opset_version: int = 11,
|
||||||
|
dynamic_batch: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Export the decoder model to ONNX format.
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
@ -412,7 +432,7 @@ def export_decoder_model_onnx(
|
|||||||
"""
|
"""
|
||||||
context_size = decoder_model.decoder.context_size
|
context_size = decoder_model.decoder.context_size
|
||||||
vocab_size = decoder_model.decoder.vocab_size
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
y = torch.zeros(10, context_size, dtype=torch.int64)
|
y = torch.zeros(1, context_size, dtype=torch.int64)
|
||||||
decoder_model = torch.jit.script(decoder_model)
|
decoder_model = torch.jit.script(decoder_model)
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
decoder_model,
|
decoder_model,
|
||||||
@ -425,7 +445,9 @@ def export_decoder_model_onnx(
|
|||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"y": {0: "N"},
|
"y": {0: "N"},
|
||||||
"decoder_out": {0: "N"},
|
"decoder_out": {0: "N"},
|
||||||
},
|
}
|
||||||
|
if dynamic_batch
|
||||||
|
else {},
|
||||||
)
|
)
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"context_size": str(context_size),
|
"context_size": str(context_size),
|
||||||
@ -438,6 +460,7 @@ def export_joiner_model_onnx(
|
|||||||
joiner_model: nn.Module,
|
joiner_model: nn.Module,
|
||||||
joiner_filename: str,
|
joiner_filename: str,
|
||||||
opset_version: int = 11,
|
opset_version: int = 11,
|
||||||
|
dynamic_batch: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Export the joiner model to ONNX format.
|
"""Export the joiner model to ONNX format.
|
||||||
The exported joiner model has two inputs:
|
The exported joiner model has two inputs:
|
||||||
@ -452,8 +475,8 @@ def export_joiner_model_onnx(
|
|||||||
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
logging.info(f"joiner dim: {joiner_dim}")
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||||
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
joiner_model,
|
joiner_model,
|
||||||
@ -470,7 +493,9 @@ def export_joiner_model_onnx(
|
|||||||
"encoder_out": {0: "N"},
|
"encoder_out": {0: "N"},
|
||||||
"decoder_out": {0: "N"},
|
"decoder_out": {0: "N"},
|
||||||
"logit": {0: "N"},
|
"logit": {0: "N"},
|
||||||
},
|
}
|
||||||
|
if dynamic_batch
|
||||||
|
else {},
|
||||||
)
|
)
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"joiner_dim": str(joiner_dim),
|
"joiner_dim": str(joiner_dim),
|
||||||
@ -629,6 +654,7 @@ def main():
|
|||||||
encoder,
|
encoder,
|
||||||
encoder_filename,
|
encoder_filename,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
|
dynamic_batch=params.dynamic_batch == 1,
|
||||||
)
|
)
|
||||||
logging.info(f"Exported encoder to {encoder_filename}")
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
@ -638,6 +664,7 @@ def main():
|
|||||||
decoder,
|
decoder,
|
||||||
decoder_filename,
|
decoder_filename,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
|
dynamic_batch=params.dynamic_batch == 1,
|
||||||
)
|
)
|
||||||
logging.info(f"Exported decoder to {decoder_filename}")
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
@ -647,37 +674,39 @@ def main():
|
|||||||
joiner,
|
joiner,
|
||||||
joiner_filename,
|
joiner_filename,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
|
dynamic_batch=params.dynamic_batch == 1,
|
||||||
)
|
)
|
||||||
logging.info(f"Exported joiner to {joiner_filename}")
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
# Generate int8 quantization models
|
# Generate int8 quantization models
|
||||||
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
logging.info("Generate int8 quantization models")
|
if params.enable_int8_quantization:
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=encoder_filename,
|
model_input=encoder_filename,
|
||||||
model_output=encoder_filename_int8,
|
model_output=encoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul", "Gather"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=joiner_filename,
|
model_input=joiner_filename,
|
||||||
model_output=joiner_filename_int8,
|
model_output=joiner_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -132,10 +132,18 @@ class OnnxModel:
|
|||||||
sess_options=self.session_opts,
|
sess_options=self.session_opts,
|
||||||
providers=["CPUExecutionProvider"],
|
providers=["CPUExecutionProvider"],
|
||||||
)
|
)
|
||||||
|
print("==========Encoder input==========")
|
||||||
|
for i in self.encoder.get_inputs():
|
||||||
|
print(i)
|
||||||
|
print("==========Encoder output==========")
|
||||||
|
for i in self.encoder.get_outputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
self.init_encoder_states()
|
self.init_encoder_states()
|
||||||
|
|
||||||
def init_encoder_states(self, batch_size: int = 1):
|
def init_encoder_states(self, batch_size: int = 1):
|
||||||
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||||
|
print(encoder_meta)
|
||||||
|
|
||||||
model_type = encoder_meta["model_type"]
|
model_type = encoder_meta["model_type"]
|
||||||
assert model_type == "zipformer", model_type
|
assert model_type == "zipformer", model_type
|
||||||
@ -232,6 +240,12 @@ class OnnxModel:
|
|||||||
sess_options=self.session_opts,
|
sess_options=self.session_opts,
|
||||||
providers=["CPUExecutionProvider"],
|
providers=["CPUExecutionProvider"],
|
||||||
)
|
)
|
||||||
|
print("==========Decoder input==========")
|
||||||
|
for i in self.decoder.get_inputs():
|
||||||
|
print(i)
|
||||||
|
print("==========Decoder output==========")
|
||||||
|
for i in self.decoder.get_outputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||||
self.context_size = int(decoder_meta["context_size"])
|
self.context_size = int(decoder_meta["context_size"])
|
||||||
@ -247,6 +261,13 @@ class OnnxModel:
|
|||||||
providers=["CPUExecutionProvider"],
|
providers=["CPUExecutionProvider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print("==========Joiner input==========")
|
||||||
|
for i in self.joiner.get_inputs():
|
||||||
|
print(i)
|
||||||
|
print("==========Joiner output==========")
|
||||||
|
for i in self.joiner.get_outputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user