mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
add export function of onnx-all-in-one to export.py
This commit is contained in:
parent
6af5a82d8f
commit
183821e6a0
@ -111,6 +111,7 @@ with the following commands:
|
||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
|
||||
"""
|
||||
|
||||
import onnx
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
@ -512,6 +513,25 @@ def export_joiner_model_onnx(
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
def export_all_in_one_onnx(encoder_filename: str, decoder_filename: str, joiner_filename: str, all_in_one_filename: str):
|
||||
encoder_onnx = onnx.load(encoder_filename)
|
||||
decoder_onnx = onnx.load(decoder_filename)
|
||||
joiner_onnx = onnx.load(joiner_filename)
|
||||
|
||||
encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
|
||||
decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
|
||||
joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
|
||||
|
||||
combined_model = onnx.compose.merge_models(
|
||||
encoder_onnx, decoder_onnx, io_map={}
|
||||
)
|
||||
combined_model = onnx.compose.merge_models(
|
||||
combined_model, joiner_onnx, io_map={}
|
||||
)
|
||||
onnx.save(combined_model, all_in_one_filename)
|
||||
logging.info(f"Saved to {all_in_one_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
@ -603,6 +623,14 @@ def main():
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
all_in_one_filename = params.exp_dir / "all_in_one.onnx"
|
||||
export_all_in_one_onnx(
|
||||
encoder_filename,
|
||||
decoder_filename,
|
||||
joiner_filename,
|
||||
all_in_one_filename
|
||||
)
|
||||
elif params.jit is True:
|
||||
logging.info("Using torch.jit.script()")
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
|
Loading…
x
Reference in New Issue
Block a user