From 7157f62af3b7712eda2186f7e3d253df7cde65b5 Mon Sep 17 00:00:00 2001
From: Yunusemre
Date: Thu, 4 Aug 2022 18:03:41 +0300
Subject: [PATCH] Merging onnx models (#518)

* add export function of onnx-all-in-one to export.py
* add onnx_check script for all-in-one onnx model
* minor fix
* remove unused arguments
* add onnx-all-in-one test
* fix style
* fix style
* fix requirements
* fix input/output names
* fix installing onnx_graphsurgeon
* fix instaliing onnx_graphsurgeon
* revert to previous requirements.txt
* fix minor
---
 ...pruned-transducer-stateless3-2022-05-13.sh |   4 +
 .../pruned_transducer_stateless3/export.py    |  33 ++
 .../onnx_check_all_in_one.py                  | 284 ++++++++++++++++++
 requirements-ci.txt                           |   1 +
 requirements.txt                              |   2 +
 5 files changed, 324 insertions(+)
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py

diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
index 2deab04b9..bdc8a3838 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
@@ -60,6 +60,10 @@ log "Decode with ONNX models"
   --onnx-decoder-filename $repo/exp/decoder.onnx \
   --onnx-joiner-filename $repo/exp/joiner.onnx

+./pruned_transducer_stateless3/onnx_check_all_in_one.py \
+  --jit-filename $repo/exp/cpu_jit.pt \
+  --onnx-all-in-one-filename $repo/exp/all_in_one.onnx
+
 ./pruned_transducer_stateless3/onnx_pretrained.py \
   --bpe-model $repo/data/lang_bpe_500/bpe.model \
   --encoder-model-filename $repo/exp/encoder.onnx \
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index 1485c6d6a..2bb518bcd 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -115,6 +115,7 @@ import argparse
 import logging
 from pathlib import Path

+import onnx
 import sentencepiece as spm
 import torch
 import torch.nn as nn
@@ -512,6 +513,30 @@ def export_joiner_model_onnx(
     logging.info(f"Saved to {joiner_filename}")


+def export_all_in_one_onnx(
+    encoder_filename: str,
+    decoder_filename: str,
+    joiner_filename: str,
+    all_in_one_filename: str,
+):
+    encoder_onnx = onnx.load(encoder_filename)
+    decoder_onnx = onnx.load(decoder_filename)
+    joiner_onnx = onnx.load(joiner_filename)
+
+    encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
+    decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
+    joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
+
+    combined_model = onnx.compose.merge_models(
+        encoder_onnx, decoder_onnx, io_map={}
+    )
+    combined_model = onnx.compose.merge_models(
+        combined_model, joiner_onnx, io_map={}
+    )
+    onnx.save(combined_model, all_in_one_filename)
+    logging.info(f"Saved to {all_in_one_filename}")
+
+
 @torch.no_grad()
 def main():
     args = get_parser().parse_args()
@@ -603,6 +628,14 @@ def main():
             joiner_filename,
             opset_version=opset_version,
         )
+
+        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
+        export_all_in_one_onnx(
+            encoder_filename,
+            decoder_filename,
+            joiner_filename,
+            all_in_one_filename,
+        )
     elif params.jit is True:
         logging.info("Using torch.jit.script()")
         # We won't use the forward() method of the model in C++, so just ignore
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py
new file mode 100755
index 000000000..b4cf8c94a
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+#
+# Copyright 2022 Xiaomi Corporation (Author: Yunus Emre Ozkose)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script checks that exported onnx models produce the same output
+with the given torchscript model for the same input.
+"""
+
+import argparse
+import logging
+import os
+
+import onnx
+import onnx_graphsurgeon as gs
+import onnxruntime
+import onnxruntime as ort
+import torch
+
+ort.set_default_logger_severity(3)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--jit-filename",
+        required=True,
+        type=str,
+        help="Path to the torchscript model",
+    )
+
+    parser.add_argument(
+        "--onnx-all-in-one-filename",
+        required=True,
+        type=str,
+        help="Path to the onnx all in one model",
+    )
+
+    return parser
+
+
+def test_encoder(
+    model: torch.jit.ScriptModule,
+    encoder_session: ort.InferenceSession,
+):
+    encoder_inputs = encoder_session.get_inputs()
+    assert encoder_inputs[0].shape == ["N", "T", 80]
+    assert encoder_inputs[1].shape == ["N"]
+    encoder_input_names = [i.name for i in encoder_inputs]
+    encoder_output_names = [i.name for i in encoder_session.get_outputs()]
+
+    for N in [1, 5]:
+        for T in [12, 25]:
+            print("N, T", N, T)
+            x = torch.rand(N, T, 80, dtype=torch.float32)
+            x_lens = torch.randint(low=10, high=T + 1, size=(N,))
+            x_lens[0] = T
+
+            encoder_inputs = {
+                encoder_input_names[0]: x.numpy(),
+                encoder_input_names[1]: x_lens.numpy(),
+            }
+            encoder_out, encoder_out_lens = encoder_session.run(
+                [encoder_output_names[1], encoder_output_names[0]],
+                encoder_inputs,
+            )
+
+            torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
+
+            encoder_out = torch.from_numpy(encoder_out)
+            assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
+                (encoder_out - torch_encoder_out).abs().max()
+            )
+
+
+def test_decoder(
+    model: torch.jit.ScriptModule,
+    decoder_session: ort.InferenceSession,
+):
+    decoder_inputs = decoder_session.get_inputs()
+    assert decoder_inputs[0].shape == ["N", 2]
+    decoder_input_names = [i.name for i in decoder_inputs]
+    decoder_output_names = [i.name for i in decoder_session.get_outputs()]
+
+    for N in [1, 5, 10]:
+        y = torch.randint(low=1, high=500, size=(N, 2))
+
+        decoder_inputs = {decoder_input_names[0]: y.numpy()}
+        decoder_out = decoder_session.run(
+            [decoder_output_names[0]],
+            decoder_inputs,
+        )[0]
+        decoder_out = torch.from_numpy(decoder_out)
+
+        torch_decoder_out = model.decoder(y, need_pad=False)
+        assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
+            (decoder_out - torch_decoder_out).abs().max()
+        )
+
+
+def test_joiner(
+    model: torch.jit.ScriptModule,
+    joiner_session: ort.InferenceSession,
+):
+    joiner_inputs = joiner_session.get_inputs()
+    assert joiner_inputs[0].shape == ["N", 512]
+    assert joiner_inputs[1].shape == ["N", 512]
+    joiner_input_names = [i.name for i in joiner_inputs]
+    joiner_output_names = [i.name for i in joiner_session.get_outputs()]
+
+    for N in [1, 5, 10]:
+        encoder_out = torch.rand(N, 512)
+        decoder_out = torch.rand(N, 512)
+
+        joiner_inputs = {
+            joiner_input_names[0]: encoder_out.numpy(),
+            joiner_input_names[1]: decoder_out.numpy(),
+        }
+        joiner_out = joiner_session.run(
+            [joiner_output_names[0]], joiner_inputs
+        )[0]
+        joiner_out = torch.from_numpy(joiner_out)
+
+        torch_joiner_out = model.joiner(
+            encoder_out,
+            decoder_out,
+            project_input=True,
+        )
+        assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
+            (joiner_out - torch_joiner_out).abs().max()
+        )
+
+
+def extract_sub_model(
+    onnx_graph: onnx.ModelProto,
+    input_op_names: list,
+    output_op_names: list,
+    non_verbose=False,
+):
+    onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
+    graph = gs.import_onnx(onnx_graph)
+    graph.cleanup().toposort()
+
+    # Extraction of input OP and output OP
+    graph_node_inputs = [
+        graph_nodes
+        for graph_nodes in graph.nodes
+        for graph_nodes_input in graph_nodes.inputs
+        if graph_nodes_input.name in input_op_names
+    ]
+    graph_node_outputs = [
+        graph_nodes
+        for graph_nodes in graph.nodes
+        for graph_nodes_output in graph_nodes.outputs
+        if graph_nodes_output.name in output_op_names
+    ]
+
+    # Init graph INPUT/OUTPUT
+    graph.inputs.clear()
+    graph.outputs.clear()
+
+    # Update graph INPUT/OUTPUT
+    graph.inputs = [
+        graph_node_input
+        for graph_node in graph_node_inputs
+        for graph_node_input in graph_node.inputs
+        if graph_node_input.shape
+    ]
+    graph.outputs = [
+        graph_node_output
+        for graph_node in graph_node_outputs
+        for graph_node_output in graph_node.outputs
+    ]
+
+    # Cleanup
+    graph.cleanup().toposort()
+
+    # Shape Estimation
+    extracted_graph = None
+    try:
+        extracted_graph = onnx.shape_inference.infer_shapes(
+            gs.export_onnx(graph)
+        )
+    except Exception:
+        extracted_graph = gs.export_onnx(graph)
+        if not non_verbose:
+            print(
+                "WARNING: "
+                + "The input shape of the next OP does not match the output shape. "
+                + "Be sure to open the .onnx file to verify the certainty of the geometry."
+            )
+    return extracted_graph
+
+
+def extract_encoder(onnx_model: onnx.ModelProto):
+    encoder_ = extract_sub_model(
+        onnx_model,
+        ["encoder/x", "encoder/x_lens"],
+        ["encoder/encoder_out", "encoder/encoder_out_lens"],
+        False,
+    )
+    onnx.save(encoder_, "tmp_encoder.onnx")
+    onnx.checker.check_model(encoder_)
+    sess = onnxruntime.InferenceSession("tmp_encoder.onnx")
+    os.remove("tmp_encoder.onnx")
+    return sess
+
+
+def extract_decoder(onnx_model: onnx.ModelProto):
+    decoder_ = extract_sub_model(
+        onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
+    )
+    onnx.save(decoder_, "tmp_decoder.onnx")
+    onnx.checker.check_model(decoder_)
+    sess = onnxruntime.InferenceSession("tmp_decoder.onnx")
+    os.remove("tmp_decoder.onnx")
+    return sess
+
+
+def extract_joiner(onnx_model: onnx.ModelProto):
+    joiner_ = extract_sub_model(
+        onnx_model,
+        ["joiner/encoder_out", "joiner/decoder_out"],
+        ["joiner/logit"],
+        False,
+    )
+    onnx.save(joiner_, "tmp_joiner.onnx")
+    onnx.checker.check_model(joiner_)
+    sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
+    os.remove("tmp_joiner.onnx")
+    return sess
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    logging.info(vars(args))
+
+    model = torch.jit.load(args.jit_filename)
+    onnx_model = onnx.load(args.onnx_all_in_one_filename)
+
+    options = ort.SessionOptions()
+    options.inter_op_num_threads = 1
+    options.intra_op_num_threads = 1
+
+    logging.info("Test encoder")
+    encoder_session = extract_encoder(onnx_model)
+    test_encoder(model, encoder_session)
+
+    logging.info("Test decoder")
+    decoder_session = extract_decoder(onnx_model)
+    test_decoder(model, decoder_session)
+
+    logging.info("Test joiner")
+    joiner_session = extract_joiner(onnx_model)
+    test_joiner(model, joiner_session)
+    logging.info("Finished checking ONNX models")
+
+
+if __name__ == "__main__":
+    torch.manual_seed(20220727)
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/requirements-ci.txt b/requirements-ci.txt
index 48769d61a..385c8737e 100644
--- a/requirements-ci.txt
+++ b/requirements-ci.txt
@@ -23,3 +23,4 @@
 multi_quantization
 onnx
 onnxruntime
+onnx_graphsurgeon -i https://pypi.ngc.nvidia.com
diff --git a/requirements.txt b/requirements.txt
index 25b5529f0..2e72d2eb6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,5 @@ typeguard
 multi_quantization
 onnx
 onnxruntime
+--extra-index-url https://pypi.ngc.nvidia.com
+onnx_graphsurgeon
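
For a quick sanity check of the merged file, a minimal sketch (assuming export.py was run with its ONNX export option enabled and the default all_in_one.onnx filename used above) is to load it with the onnx Python API and list the prefixed inputs and outputs that onnx.compose.add_prefix introduces, e.g. encoder/x, decoder/y, joiner/logit:

import onnx

# Hypothetical path; adjust to the experiment directory used for export.
model = onnx.load("all_in_one.onnx")
onnx.checker.check_model(model)

# With io_map={} the merged graph keeps every sub-model input/output,
# each carrying its "encoder/", "decoder/" or "joiner/" prefix.
print("inputs: ", [i.name for i in model.graph.input])
print("outputs:", [o.name for o in model.graph.output])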