fix style

This commit is contained in:
Yunus Emre Özköse 2022-08-03 17:05:16 +03:00
parent d382455945
commit d2de493cd8
3 changed files with 67 additions and 24 deletions

View File

@ -60,8 +60,9 @@ log "Decode with ONNX models"
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx
./pruned_transducer_stateless3/onnx_check_all_in_one.py --jit-filename $repo/exp/cpu_jit.pt \
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
./pruned_transducer_stateless3/onnx_check_all_in_one.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
./pruned_transducer_stateless3/onnx_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \

View File

@ -111,11 +111,11 @@ with the following commands:
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
"""
import onnx
import argparse
import logging
from pathlib import Path
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
@ -513,7 +513,12 @@ def export_joiner_model_onnx(
logging.info(f"Saved to {joiner_filename}")
def export_all_in_one_onnx(encoder_filename: str, decoder_filename: str, joiner_filename: str, all_in_one_filename: str):
def export_all_in_one_onnx(
encoder_filename: str,
decoder_filename: str,
joiner_filename: str,
all_in_one_filename: str,
):
encoder_onnx = onnx.load(encoder_filename)
decoder_onnx = onnx.load(decoder_filename)
joiner_onnx = onnx.load(joiner_filename)
@ -629,7 +634,7 @@ def main():
encoder_filename,
decoder_filename,
joiner_filename,
all_in_one_filename
all_in_one_filename,
)
elif params.jit is True:
logging.info("Using torch.jit.script()")

View File

@ -21,16 +21,15 @@ This script checks that exported onnx models produce the same output
with the given torchscript model for the same input.
"""
import os
import argparse
import logging
import onnxruntime as ort
import torch
import os
import onnx
import onnxruntime
import onnx_graphsurgeon as gs
import onnxruntime
import onnxruntime as ort
import torch
ort.set_default_logger_severity(3)
@ -146,22 +145,46 @@ def test_joiner(
)
def extract_sub_model(onnx_graph: onnx.ModelProto, input_op_names: list, output_op_names: list, non_verbose=False):
def extract_sub_model(
onnx_graph: onnx.ModelProto,
input_op_names: list,
output_op_names: list,
non_verbose=False,
):
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
graph = gs.import_onnx(onnx_graph)
graph.cleanup().toposort()
# Extraction of input OP and output OP
graph_node_inputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_input in graph_nodes.inputs if graph_nodes_input.name in input_op_names]
graph_node_outputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_output in graph_nodes.outputs if graph_nodes_output.name in output_op_names]
graph_node_inputs = [
graph_nodes
for graph_nodes in graph.nodes
for graph_nodes_input in graph_nodes.inputs
if graph_nodes_input.name in input_op_names
]
graph_node_outputs = [
graph_nodes
for graph_nodes in graph.nodes
for graph_nodes_output in graph_nodes.outputs
if graph_nodes_output.name in output_op_names
]
# Init graph INPUT/OUTPUT
graph.inputs.clear()
graph.outputs.clear()
# Update graph INPUT/OUTPUT
graph.inputs = [graph_node_input for graph_node in graph_node_inputs for graph_node_input in graph_node.inputs if graph_node_input.shape]
graph.outputs = [graph_node_output for graph_node in graph_node_outputs for graph_node_output in graph_node.outputs]
graph.inputs = [
graph_node_input
for graph_node in graph_node_inputs
for graph_node_input in graph_node.inputs
if graph_node_input.shape
]
graph.outputs = [
graph_node_output
for graph_node in graph_node_outputs
for graph_node_output in graph_node.outputs
]
# Cleanup
graph.cleanup().toposort()
@ -169,20 +192,27 @@ def extract_sub_model(onnx_graph: onnx.ModelProto, input_op_names: list, output_
# Shape Estimation
extracted_graph = None
try:
extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
except Exception as e:
extracted_graph = onnx.shape_inference.infer_shapes(
gs.export_onnx(graph)
)
except Exception:
extracted_graph = gs.export_onnx(graph)
if not non_verbose:
print(
f'WARNING: ' +
'The input shape of the next OP does not match the output shape. ' +
'Be sure to open the .onnx file to verify the certainty of the geometry.'
"WARNING: "
+ "The input shape of the next OP does not match the output shape. "
+ "Be sure to open the .onnx file to verify the certainty of the geometry."
)
return extracted_graph
def extract_encoder(onnx_model: onnx.ModelProto):
encoder_ = extract_sub_model(onnx_model, ["encoder/x", "encoder/x_lens"], ["encoder/encoder_out", "encoder/encoder_out_lens"], False)
encoder_ = extract_sub_model(
onnx_model,
["encoder/x", "encoder/x_lens"],
["encoder/encoder_out", "encoder/encoder_out_lens"],
False,
)
onnx.save(encoder_, "tmp_encoder.onnx")
onnx.checker.check_model(encoder_)
sess = onnxruntime.InferenceSession("tmp_encoder.onnx")
@ -191,7 +221,9 @@ def extract_encoder(onnx_model: onnx.ModelProto):
def extract_decoder(onnx_model: onnx.ModelProto):
decoder_ = extract_sub_model(onnx_model, ["decoder/y"], ["decoder/decoder_out"], False)
decoder_ = extract_sub_model(
onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
)
onnx.save(decoder_, "tmp_decoder.onnx")
onnx.checker.check_model(decoder_)
sess = onnxruntime.InferenceSession("tmp_decoder.onnx")
@ -200,7 +232,12 @@ def extract_decoder(onnx_model: onnx.ModelProto):
def extract_joiner(onnx_model: onnx.ModelProto):
joiner_ = extract_sub_model(onnx_model, ["joiner/encoder_out", "joiner/decoder_out"], ["joiner/logit"], False)
joiner_ = extract_sub_model(
onnx_model,
["joiner/encoder_out", "joiner/decoder_out"],
["joiner/logit"],
False,
)
onnx.save(joiner_, "tmp_joiner.onnx")
onnx.checker.check_model(joiner_)
sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
@ -223,7 +260,7 @@ def main():
logging.info("Test encoder")
encoder_session = extract_encoder(onnx_model)
test_encoder(model, encoder_session)
logging.info("Test decoder")
decoder_session = extract_decoder(onnx_model)
test_decoder(model, decoder_session)