fix style
This commit is contained in:
parent
d382455945
commit
d2de493cd8
@ -60,8 +60,9 @@ log "Decode with ONNX models"
|
|||||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||||
--onnx-joiner-filename $repo/exp/joiner.onnx
|
--onnx-joiner-filename $repo/exp/joiner.onnx
|
||||||
|
|
||||||
./pruned_transducer_stateless3/onnx_check_all_in_one.py --jit-filename $repo/exp/cpu_jit.pt \
|
./pruned_transducer_stateless3/onnx_check_all_in_one.py \
|
||||||
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
|
--jit-filename $repo/exp/cpu_jit.pt \
|
||||||
|
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
|
||||||
|
|
||||||
./pruned_transducer_stateless3/onnx_pretrained.py \
|
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
|||||||
@ -111,11 +111,11 @@ with the following commands:
|
|||||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
|
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import onnx
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import onnx
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -513,7 +513,12 @@ def export_joiner_model_onnx(
|
|||||||
logging.info(f"Saved to {joiner_filename}")
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
def export_all_in_one_onnx(encoder_filename: str, decoder_filename: str, joiner_filename: str, all_in_one_filename: str):
|
def export_all_in_one_onnx(
|
||||||
|
encoder_filename: str,
|
||||||
|
decoder_filename: str,
|
||||||
|
joiner_filename: str,
|
||||||
|
all_in_one_filename: str,
|
||||||
|
):
|
||||||
encoder_onnx = onnx.load(encoder_filename)
|
encoder_onnx = onnx.load(encoder_filename)
|
||||||
decoder_onnx = onnx.load(decoder_filename)
|
decoder_onnx = onnx.load(decoder_filename)
|
||||||
joiner_onnx = onnx.load(joiner_filename)
|
joiner_onnx = onnx.load(joiner_filename)
|
||||||
@ -629,7 +634,7 @@ def main():
|
|||||||
encoder_filename,
|
encoder_filename,
|
||||||
decoder_filename,
|
decoder_filename,
|
||||||
joiner_filename,
|
joiner_filename,
|
||||||
all_in_one_filename
|
all_in_one_filename,
|
||||||
)
|
)
|
||||||
elif params.jit is True:
|
elif params.jit is True:
|
||||||
logging.info("Using torch.jit.script()")
|
logging.info("Using torch.jit.script()")
|
||||||
|
|||||||
@ -21,16 +21,15 @@ This script checks that exported onnx models produce the same output
|
|||||||
with the given torchscript model for the same input.
|
with the given torchscript model for the same input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import onnxruntime as ort
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import onnx
|
import onnx
|
||||||
import onnxruntime
|
|
||||||
import onnx_graphsurgeon as gs
|
import onnx_graphsurgeon as gs
|
||||||
|
import onnxruntime
|
||||||
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
|
||||||
ort.set_default_logger_severity(3)
|
ort.set_default_logger_severity(3)
|
||||||
|
|
||||||
@ -146,22 +145,46 @@ def test_joiner(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_sub_model(onnx_graph: onnx.ModelProto, input_op_names: list, output_op_names: list, non_verbose=False):
|
def extract_sub_model(
|
||||||
|
onnx_graph: onnx.ModelProto,
|
||||||
|
input_op_names: list,
|
||||||
|
output_op_names: list,
|
||||||
|
non_verbose=False,
|
||||||
|
):
|
||||||
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
|
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
|
||||||
graph = gs.import_onnx(onnx_graph)
|
graph = gs.import_onnx(onnx_graph)
|
||||||
graph.cleanup().toposort()
|
graph.cleanup().toposort()
|
||||||
|
|
||||||
# Extraction of input OP and output OP
|
# Extraction of input OP and output OP
|
||||||
graph_node_inputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_input in graph_nodes.inputs if graph_nodes_input.name in input_op_names]
|
graph_node_inputs = [
|
||||||
graph_node_outputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_output in graph_nodes.outputs if graph_nodes_output.name in output_op_names]
|
graph_nodes
|
||||||
|
for graph_nodes in graph.nodes
|
||||||
|
for graph_nodes_input in graph_nodes.inputs
|
||||||
|
if graph_nodes_input.name in input_op_names
|
||||||
|
]
|
||||||
|
graph_node_outputs = [
|
||||||
|
graph_nodes
|
||||||
|
for graph_nodes in graph.nodes
|
||||||
|
for graph_nodes_output in graph_nodes.outputs
|
||||||
|
if graph_nodes_output.name in output_op_names
|
||||||
|
]
|
||||||
|
|
||||||
# Init graph INPUT/OUTPUT
|
# Init graph INPUT/OUTPUT
|
||||||
graph.inputs.clear()
|
graph.inputs.clear()
|
||||||
graph.outputs.clear()
|
graph.outputs.clear()
|
||||||
|
|
||||||
# Update graph INPUT/OUTPUT
|
# Update graph INPUT/OUTPUT
|
||||||
graph.inputs = [graph_node_input for graph_node in graph_node_inputs for graph_node_input in graph_node.inputs if graph_node_input.shape]
|
graph.inputs = [
|
||||||
graph.outputs = [graph_node_output for graph_node in graph_node_outputs for graph_node_output in graph_node.outputs]
|
graph_node_input
|
||||||
|
for graph_node in graph_node_inputs
|
||||||
|
for graph_node_input in graph_node.inputs
|
||||||
|
if graph_node_input.shape
|
||||||
|
]
|
||||||
|
graph.outputs = [
|
||||||
|
graph_node_output
|
||||||
|
for graph_node in graph_node_outputs
|
||||||
|
for graph_node_output in graph_node.outputs
|
||||||
|
]
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
graph.cleanup().toposort()
|
graph.cleanup().toposort()
|
||||||
@ -169,20 +192,27 @@ def extract_sub_model(onnx_graph: onnx.ModelProto, input_op_names: list, output_
|
|||||||
# Shape Estimation
|
# Shape Estimation
|
||||||
extracted_graph = None
|
extracted_graph = None
|
||||||
try:
|
try:
|
||||||
extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
|
extracted_graph = onnx.shape_inference.infer_shapes(
|
||||||
except Exception as e:
|
gs.export_onnx(graph)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
extracted_graph = gs.export_onnx(graph)
|
extracted_graph = gs.export_onnx(graph)
|
||||||
if not non_verbose:
|
if not non_verbose:
|
||||||
print(
|
print(
|
||||||
f'WARNING: ' +
|
"WARNING: "
|
||||||
'The input shape of the next OP does not match the output shape. ' +
|
+ "The input shape of the next OP does not match the output shape. "
|
||||||
'Be sure to open the .onnx file to verify the certainty of the geometry.'
|
+ "Be sure to open the .onnx file to verify the certainty of the geometry."
|
||||||
)
|
)
|
||||||
return extracted_graph
|
return extracted_graph
|
||||||
|
|
||||||
|
|
||||||
def extract_encoder(onnx_model: onnx.ModelProto):
|
def extract_encoder(onnx_model: onnx.ModelProto):
|
||||||
encoder_ = extract_sub_model(onnx_model, ["encoder/x", "encoder/x_lens"], ["encoder/encoder_out", "encoder/encoder_out_lens"], False)
|
encoder_ = extract_sub_model(
|
||||||
|
onnx_model,
|
||||||
|
["encoder/x", "encoder/x_lens"],
|
||||||
|
["encoder/encoder_out", "encoder/encoder_out_lens"],
|
||||||
|
False,
|
||||||
|
)
|
||||||
onnx.save(encoder_, "tmp_encoder.onnx")
|
onnx.save(encoder_, "tmp_encoder.onnx")
|
||||||
onnx.checker.check_model(encoder_)
|
onnx.checker.check_model(encoder_)
|
||||||
sess = onnxruntime.InferenceSession("tmp_encoder.onnx")
|
sess = onnxruntime.InferenceSession("tmp_encoder.onnx")
|
||||||
@ -191,7 +221,9 @@ def extract_encoder(onnx_model: onnx.ModelProto):
|
|||||||
|
|
||||||
|
|
||||||
def extract_decoder(onnx_model: onnx.ModelProto):
|
def extract_decoder(onnx_model: onnx.ModelProto):
|
||||||
decoder_ = extract_sub_model(onnx_model, ["decoder/y"], ["decoder/decoder_out"], False)
|
decoder_ = extract_sub_model(
|
||||||
|
onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
|
||||||
|
)
|
||||||
onnx.save(decoder_, "tmp_decoder.onnx")
|
onnx.save(decoder_, "tmp_decoder.onnx")
|
||||||
onnx.checker.check_model(decoder_)
|
onnx.checker.check_model(decoder_)
|
||||||
sess = onnxruntime.InferenceSession("tmp_decoder.onnx")
|
sess = onnxruntime.InferenceSession("tmp_decoder.onnx")
|
||||||
@ -200,7 +232,12 @@ def extract_decoder(onnx_model: onnx.ModelProto):
|
|||||||
|
|
||||||
|
|
||||||
def extract_joiner(onnx_model: onnx.ModelProto):
|
def extract_joiner(onnx_model: onnx.ModelProto):
|
||||||
joiner_ = extract_sub_model(onnx_model, ["joiner/encoder_out", "joiner/decoder_out"], ["joiner/logit"], False)
|
joiner_ = extract_sub_model(
|
||||||
|
onnx_model,
|
||||||
|
["joiner/encoder_out", "joiner/decoder_out"],
|
||||||
|
["joiner/logit"],
|
||||||
|
False,
|
||||||
|
)
|
||||||
onnx.save(joiner_, "tmp_joiner.onnx")
|
onnx.save(joiner_, "tmp_joiner.onnx")
|
||||||
onnx.checker.check_model(joiner_)
|
onnx.checker.check_model(joiner_)
|
||||||
sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
|
sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
|
||||||
|
|||||||
Reference in New Issue
Block a user