diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
index f27bbe308..bdc8a3838 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
@@ -60,8 +60,9 @@ log "Decode with ONNX models"
   --onnx-decoder-filename $repo/exp/decoder.onnx \
   --onnx-joiner-filename $repo/exp/joiner.onnx
 
-./pruned_transducer_stateless3/onnx_check_all_in_one.py --jit-filename $repo/exp/cpu_jit.pt \
-  --onnx-all-in-one-filename $repo/exp/all_in_one.onnx
+./pruned_transducer_stateless3/onnx_check_all_in_one.py \
+  --jit-filename $repo/exp/cpu_jit.pt \
+  --onnx-all-in-one-filename $repo/exp/all_in_one.onnx
 
 ./pruned_transducer_stateless3/onnx_pretrained.py \
   --bpe-model $repo/data/lang_bpe_500/bpe.model \
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index 2f7c1022e..2bb518bcd 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -111,11 +111,11 @@ with the following commands:
 # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
 """
 
-import onnx
 import argparse
 import logging
 from pathlib import Path
 
+import onnx
 import sentencepiece as spm
 import torch
 import torch.nn as nn
@@ -513,7 +513,12 @@ def export_joiner_model_onnx(
     logging.info(f"Saved to {joiner_filename}")
 
 
-def export_all_in_one_onnx(encoder_filename: str, decoder_filename: str, joiner_filename: str, all_in_one_filename: str):
+def export_all_in_one_onnx(
+    encoder_filename: str,
+    decoder_filename: str,
+    joiner_filename: str,
+    all_in_one_filename: str,
+):
     encoder_onnx = onnx.load(encoder_filename)
     decoder_onnx = onnx.load(decoder_filename)
     joiner_onnx = onnx.load(joiner_filename)
@@ -629,7 +634,7 @@ def main():
             encoder_filename,
             decoder_filename,
             joiner_filename,
-            all_in_one_filename
+            all_in_one_filename,
         )
     elif params.jit is True:
         logging.info("Using torch.jit.script()")
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py
index fd9af083e..d9a23e1b6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py
@@ -21,16 +21,15 @@ This script checks that exported onnx models produce the same output
 with the given torchscript model for the same input.
 """
 
-import os
 import argparse
 import logging
-
-import onnxruntime as ort
-import torch
+import os
 
 import onnx
-import onnxruntime
 import onnx_graphsurgeon as gs
+import onnxruntime
+import onnxruntime as ort
+import torch
 
 
 ort.set_default_logger_severity(3)
@@ -146,22 +145,46 @@ def test_joiner(
     )
 
 
-def extract_sub_model(onnx_graph: onnx.ModelProto, input_op_names: list, output_op_names: list, non_verbose=False):
+def extract_sub_model(
+    onnx_graph: onnx.ModelProto,
+    input_op_names: list,
+    output_op_names: list,
+    non_verbose=False,
+):
     onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
     graph = gs.import_onnx(onnx_graph)
     graph.cleanup().toposort()
 
     # Extraction of input OP and output OP
-    graph_node_inputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_input in graph_nodes.inputs if graph_nodes_input.name in input_op_names]
-    graph_node_outputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_output in graph_nodes.outputs if graph_nodes_output.name in output_op_names]
+    graph_node_inputs = [
+        graph_nodes
+        for graph_nodes in graph.nodes
+        for graph_nodes_input in graph_nodes.inputs
+        if graph_nodes_input.name in input_op_names
+    ]
+    graph_node_outputs = [
+        graph_nodes
+        for graph_nodes in graph.nodes
+        for graph_nodes_output in graph_nodes.outputs
+        if graph_nodes_output.name in output_op_names
+    ]
 
     # Init graph INPUT/OUTPUT
     graph.inputs.clear()
     graph.outputs.clear()
 
     # Update graph INPUT/OUTPUT
-    graph.inputs = [graph_node_input for graph_node in graph_node_inputs for graph_node_input in graph_node.inputs if graph_node_input.shape]
-    graph.outputs = [graph_node_output for graph_node in graph_node_outputs for graph_node_output in graph_node.outputs]
+    graph.inputs = [
+        graph_node_input
+        for graph_node in graph_node_inputs
+        for graph_node_input in graph_node.inputs
+        if graph_node_input.shape
+    ]
+    graph.outputs = [
+        graph_node_output
+        for graph_node in graph_node_outputs
+        for graph_node_output in graph_node.outputs
+    ]
 
     # Cleanup
     graph.cleanup().toposort()
@@ -169,20 +192,27 @@ def extract_sub_model(onnx_graph: onnx.ModelProto, input_op_names: list, output_
     # Shape Estimation
     extracted_graph = None
     try:
-        extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
-    except Exception as e:
+        extracted_graph = onnx.shape_inference.infer_shapes(
+            gs.export_onnx(graph)
+        )
+    except Exception:
         extracted_graph = gs.export_onnx(graph)
         if not non_verbose:
             print(
-                f'WARNING: ' +
-                'The input shape of the next OP does not match the output shape. ' +
-                'Be sure to open the .onnx file to verify the certainty of the geometry.'
+                "WARNING: "
+                + "The input shape of the next OP does not match the output shape. "
+                + "Be sure to open the .onnx file to verify the certainty of the geometry."
             )
     return extracted_graph
 
 
 def extract_encoder(onnx_model: onnx.ModelProto):
-    encoder_ = extract_sub_model(onnx_model, ["encoder/x", "encoder/x_lens"], ["encoder/encoder_out", "encoder/encoder_out_lens"], False)
+    encoder_ = extract_sub_model(
+        onnx_model,
+        ["encoder/x", "encoder/x_lens"],
+        ["encoder/encoder_out", "encoder/encoder_out_lens"],
+        False,
+    )
     onnx.save(encoder_, "tmp_encoder.onnx")
     onnx.checker.check_model(encoder_)
     sess = onnxruntime.InferenceSession("tmp_encoder.onnx")
@@ -191,7 +221,9 @@ def extract_encoder(onnx_model: onnx.ModelProto):
 
 
 def extract_decoder(onnx_model: onnx.ModelProto):
-    decoder_ = extract_sub_model(onnx_model, ["decoder/y"], ["decoder/decoder_out"], False)
+    decoder_ = extract_sub_model(
+        onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
+    )
     onnx.save(decoder_, "tmp_decoder.onnx")
     onnx.checker.check_model(decoder_)
     sess = onnxruntime.InferenceSession("tmp_decoder.onnx")
@@ -200,7 +232,12 @@ def extract_decoder(onnx_model: onnx.ModelProto):
 
 
 def extract_joiner(onnx_model: onnx.ModelProto):
-    joiner_ = extract_sub_model(onnx_model, ["joiner/encoder_out", "joiner/decoder_out"], ["joiner/logit"], False)
+    joiner_ = extract_sub_model(
+        onnx_model,
+        ["joiner/encoder_out", "joiner/decoder_out"],
+        ["joiner/logit"],
+        False,
+    )
     onnx.save(joiner_, "tmp_joiner.onnx")
     onnx.checker.check_model(joiner_)
     sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
@@ -223,7 +260,7 @@ def main():
     logging.info("Test encoder")
     encoder_session = extract_encoder(onnx_model)
     test_encoder(model, encoder_session)
-    
+
     logging.info("Test decoder")
     decoder_session = extract_decoder(onnx_model)
     test_decoder(model, decoder_session)