fix style
commit d2de493cd8 (parent d382455945)
@@ -60,8 +60,9 @@ log "Decode with ONNX models"
   --onnx-decoder-filename $repo/exp/decoder.onnx \
   --onnx-joiner-filename $repo/exp/joiner.onnx

-./pruned_transducer_stateless3/onnx_check_all_in_one.py --jit-filename $repo/exp/cpu_jit.pt \
-  --onnx-all-in-one-filename $repo/exp/all_in_one.onnx
+./pruned_transducer_stateless3/onnx_check_all_in_one.py \
+  --jit-filename $repo/exp/cpu_jit.pt \
+  --onnx-all-in-one-filename $repo/exp/all_in_one.onnx

 ./pruned_transducer_stateless3/onnx_pretrained.py \
   --bpe-model $repo/data/lang_bpe_500/bpe.model \
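
For context, what the check scripts invoked above assert: the torchscript model and the exported ONNX graphs should produce numerically close outputs for the same input. A minimal sketch of that idea; the file names, tensor names, `.encoder` calling convention, and feature shape below are assumptions, not taken from the scripts themselves:

import numpy as np
import onnxruntime
import torch

# Run the torchscript encoder on a random utterance.
jit_model = torch.jit.load("cpu_jit.pt")
x = torch.rand(1, 100, 80)  # (batch, num_frames, feature_dim) -- assumed shape
x_lens = torch.tensor([100], dtype=torch.int64)
jit_out, _ = jit_model.encoder(x, x_lens)

# Run the exported ONNX encoder on the same input (tensor names assumed).
sess = onnxruntime.InferenceSession("encoder.onnx")
onnx_out, _ = sess.run(
    ["encoder_out", "encoder_out_lens"],
    {"x": x.numpy(), "x_lens": x_lens.numpy()},
)
assert np.allclose(jit_out.detach().numpy(), onnx_out, atol=1e-4)
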
@@ -111,11 +111,11 @@ with the following commands:
 # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
 """

-import onnx
 import argparse
 import logging
 from pathlib import Path

+import onnx
 import sentencepiece as spm
 import torch
 import torch.nn as nn
@@ -513,7 +513,12 @@ def export_joiner_model_onnx(
     logging.info(f"Saved to {joiner_filename}")


-def export_all_in_one_onnx(encoder_filename: str, decoder_filename: str, joiner_filename: str, all_in_one_filename: str):
+def export_all_in_one_onnx(
+    encoder_filename: str,
+    decoder_filename: str,
+    joiner_filename: str,
+    all_in_one_filename: str,
+):
     encoder_onnx = onnx.load(encoder_filename)
     decoder_onnx = onnx.load(decoder_filename)
     joiner_onnx = onnx.load(joiner_filename)
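
For context, the rest of export_all_in_one_onnx (outside this hunk) packs the three loaded graphs into a single file whose tensor names carry encoder/, decoder/, and joiner/ prefixes, which are the names the check script below relies on. A minimal sketch of one way to do that with onnx.compose; this is an illustration, not necessarily the recipe's exact implementation:

import onnx
import onnx.compose

# Prefix every tensor name in each sub-model so the three graphs cannot
# clash, then merge with an empty io_map (no cross-connections): the result
# is one ModelProto holding three independent sub-graphs.
encoder_onnx = onnx.compose.add_prefix(onnx.load("encoder.onnx"), prefix="encoder/")
decoder_onnx = onnx.compose.add_prefix(onnx.load("decoder.onnx"), prefix="decoder/")
joiner_onnx = onnx.compose.add_prefix(onnx.load("joiner.onnx"), prefix="joiner/")

all_in_one = onnx.compose.merge_models(encoder_onnx, decoder_onnx, io_map=[])
all_in_one = onnx.compose.merge_models(all_in_one, joiner_onnx, io_map=[])
onnx.save(all_in_one, "all_in_one.onnx")
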
@@ -629,7 +634,7 @@ def main():
             encoder_filename,
             decoder_filename,
             joiner_filename,
-            all_in_one_filename
+            all_in_one_filename,
         )
     elif params.jit is True:
         logging.info("Using torch.jit.script()")
@@ -21,16 +21,15 @@ This script checks that exported onnx models produce the same output
 with the given torchscript model for the same input.
 """

-import os
 import argparse
 import logging
+import os

-import onnxruntime as ort
-import torch
-
 import onnx
+import onnx_graphsurgeon as gs
 import onnxruntime
-import onnx_graphsurgeon as gs
+import onnxruntime as ort
+import torch

 ort.set_default_logger_severity(3)

@@ -146,22 +145,46 @@ def test_joiner(
     )


-def extract_sub_model(onnx_graph: onnx.ModelProto, input_op_names: list, output_op_names: list, non_verbose=False):
+def extract_sub_model(
+    onnx_graph: onnx.ModelProto,
+    input_op_names: list,
+    output_op_names: list,
+    non_verbose=False,
+):
     onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
     graph = gs.import_onnx(onnx_graph)
     graph.cleanup().toposort()

     # Extraction of input OP and output OP
-    graph_node_inputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_input in graph_nodes.inputs if graph_nodes_input.name in input_op_names]
-    graph_node_outputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_output in graph_nodes.outputs if graph_nodes_output.name in output_op_names]
+    graph_node_inputs = [
+        graph_nodes
+        for graph_nodes in graph.nodes
+        for graph_nodes_input in graph_nodes.inputs
+        if graph_nodes_input.name in input_op_names
+    ]
+    graph_node_outputs = [
+        graph_nodes
+        for graph_nodes in graph.nodes
+        for graph_nodes_output in graph_nodes.outputs
+        if graph_nodes_output.name in output_op_names
+    ]

     # Init graph INPUT/OUTPUT
     graph.inputs.clear()
     graph.outputs.clear()

     # Update graph INPUT/OUTPUT
-    graph.inputs = [graph_node_input for graph_node in graph_node_inputs for graph_node_input in graph_node.inputs if graph_node_input.shape]
-    graph.outputs = [graph_node_output for graph_node in graph_node_outputs for graph_node_output in graph_node.outputs]
+    graph.inputs = [
+        graph_node_input
+        for graph_node in graph_node_inputs
+        for graph_node_input in graph_node.inputs
+        if graph_node_input.shape
+    ]
+    graph.outputs = [
+        graph_node_output
+        for graph_node in graph_node_outputs
+        for graph_node_output in graph_node.outputs
+    ]

     # Cleanup
     graph.cleanup().toposort()
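
For context, extract_sub_model is the workhorse of this script: given tensor names to treat as cut points, it rebuilds a standalone graph between them. A hypothetical direct use (the all_in_one.onnx path is an assumption; the extract_decoder wrapper further down does essentially this):

import onnx

all_in_one = onnx.load("all_in_one.onnx")
decoder_graph = extract_sub_model(
    all_in_one,
    input_op_names=["decoder/y"],
    output_op_names=["decoder/decoder_out"],
    non_verbose=True,
)
onnx.save(decoder_graph, "decoder_sub.onnx")
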
@@ -169,20 +192,27 @@ def extract_sub_model(onnx_graph: onnx.ModelProto, input_op_names: list, output_
     # Shape Estimation
     extracted_graph = None
     try:
-        extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
-    except Exception as e:
+        extracted_graph = onnx.shape_inference.infer_shapes(
+            gs.export_onnx(graph)
+        )
+    except Exception:
         extracted_graph = gs.export_onnx(graph)
         if not non_verbose:
             print(
-                f'WARNING: ' +
-                'The input shape of the next OP does not match the output shape. ' +
-                'Be sure to open the .onnx file to verify the certainty of the geometry.'
+                "WARNING: "
+                + "The input shape of the next OP does not match the output shape. "
+                + "Be sure to open the .onnx file to verify the certainty of the geometry."
             )
     return extracted_graph


 def extract_encoder(onnx_model: onnx.ModelProto):
-    encoder_ = extract_sub_model(onnx_model, ["encoder/x", "encoder/x_lens"], ["encoder/encoder_out", "encoder/encoder_out_lens"], False)
+    encoder_ = extract_sub_model(
+        onnx_model,
+        ["encoder/x", "encoder/x_lens"],
+        ["encoder/encoder_out", "encoder/encoder_out_lens"],
+        False,
+    )
     onnx.save(encoder_, "tmp_encoder.onnx")
     onnx.checker.check_model(encoder_)
     sess = onnxruntime.InferenceSession("tmp_encoder.onnx")
@@ -191,7 +221,9 @@ def extract_encoder(onnx_model: onnx.ModelProto):


 def extract_decoder(onnx_model: onnx.ModelProto):
-    decoder_ = extract_sub_model(onnx_model, ["decoder/y"], ["decoder/decoder_out"], False)
+    decoder_ = extract_sub_model(
+        onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
+    )
     onnx.save(decoder_, "tmp_decoder.onnx")
     onnx.checker.check_model(decoder_)
     sess = onnxruntime.InferenceSession("tmp_decoder.onnx")
@@ -200,7 +232,12 @@ def extract_decoder(onnx_model: onnx.ModelProto):


 def extract_joiner(onnx_model: onnx.ModelProto):
-    joiner_ = extract_sub_model(onnx_model, ["joiner/encoder_out", "joiner/decoder_out"], ["joiner/logit"], False)
+    joiner_ = extract_sub_model(
+        onnx_model,
+        ["joiner/encoder_out", "joiner/decoder_out"],
+        ["joiner/logit"],
+        False,
+    )
     onnx.save(joiner_, "tmp_joiner.onnx")
     onnx.checker.check_model(joiner_)
     sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
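
A hypothetical smoke test for the freshly extracted joiner: feed random activations through it and check that a logit matrix comes back. The 512-dimensional projection size is an assumption, not taken from the recipe:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
encoder_out = np.random.randn(1, 512).astype(np.float32)
decoder_out = np.random.randn(1, 512).astype(np.float32)
(logit,) = sess.run(
    ["joiner/logit"],
    {"joiner/encoder_out": encoder_out, "joiner/decoder_out": decoder_out},
)
print(logit.shape)  # e.g. (1, 500) for the BPE-500 model
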
@@ -223,7 +260,7 @@ def main():
     logging.info("Test encoder")
     encoder_session = extract_encoder(onnx_model)
     test_encoder(model, encoder_session)
-
+
     logging.info("Test decoder")
     decoder_session = extract_decoder(onnx_model)
     test_decoder(model, decoder_session)