add export function of onnx-all-in-one to export.py

This commit is contained in:
Yunus Emre Özköse 2022-08-03 15:37:23 +03:00
parent 6af5a82d8f
commit 183821e6a0

View File

@ -111,6 +111,7 @@ with the following commands:
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
"""
import onnx
import argparse
import logging
from pathlib import Path
@ -512,6 +513,25 @@ def export_joiner_model_onnx(
logging.info(f"Saved to {joiner_filename}")
def export_all_in_one_onnx(encoder_filename: str, decoder_filename: str, joiner_filename: str, all_in_one_filename: str):
encoder_onnx = onnx.load(encoder_filename)
decoder_onnx = onnx.load(decoder_filename)
joiner_onnx = onnx.load(joiner_filename)
encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
combined_model = onnx.compose.merge_models(
encoder_onnx, decoder_onnx, io_map={}
)
combined_model = onnx.compose.merge_models(
combined_model, joiner_onnx, io_map={}
)
onnx.save(combined_model, all_in_one_filename)
logging.info(f"Saved to {all_in_one_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
@ -603,6 +623,14 @@ def main():
joiner_filename,
opset_version=opset_version,
)
all_in_one_filename = params.exp_dir / "all_in_one.onnx"
export_all_in_one_onnx(
encoder_filename,
decoder_filename,
joiner_filename,
all_in_one_filename
)
elif params.jit is True:
logging.info("Using torch.jit.script()")
# We won't use the forward() method of the model in C++, so just ignore