Revert merging three ONNX models into a single one.

It's quite time-consuming to extract a sub-graph from the combined
model. For instance, extracting the encoder model alone takes more
than one hour.
Fangjun Kuang 2022-07-30 11:33:22 +08:00
parent 3ebb52aa9b
commit c70df281c6
2 changed files with 9 additions and 28 deletions
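For reference, extracting a sub-graph from the combined model relied on `onnx.utils.Extractor`, as described in the documentation removed below. A minimal sketch of that (slow) extraction step, assuming the `encoder/` prefix added by the removed merge code and the tensor names used elsewhere in this commit:

    import onnx
    import onnx.utils

    combined = onnx.load("all_in_one.onnx")
    extractor = onnx.utils.Extractor(combined)
    # Slicing the encoder out of the merged graph is the step that
    # took more than one hour.
    encoder = extractor.extract_model(
        input_names=["encoder/x", "encoder/x_lens"],
        output_names=["encoder/encoder_out", "encoder/encoder_out_lens"],
    )
    onnx.save(encoder, "encoder.onnx")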

View File

@@ -34,7 +34,7 @@ It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
 load it by `torch.jit.load("cpu_jit.pt")`.
 Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
-are on CPU. You can use `to("cuda")` to move it to a CUDA device.
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
 (2) Export to ONNX format
@@ -51,10 +51,7 @@ See `onnx_check.py` to see how to use it.
 - encoder.onnx
 - decoder.onnx
 - joiner.onnx
-- all_in_one.onnx
-The file all_in_one.onnx combines `encoder.onnx`, `decoder.onnx`, and
-`joiner.onnx`. You can use `onnx.utils.Extractor` to extract them.
 (3) Export `model.state_dict()`
@@ -427,25 +424,6 @@ def main():
             joiner_filename,
             opset_version=opset_version,
         )
-
-        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
-        encoder_onnx = onnx.load(encoder_filename)
-        decoder_onnx = onnx.load(decoder_filename)
-        joiner_onnx = onnx.load(joiner_filename)
-
-        encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
-        decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
-        joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
-
-        combined_model = onnx.compose.merge_models(
-            encoder_onnx, decoder_onnx, io_map={}
-        )
-        combined_model = onnx.compose.merge_models(
-            combined_model, joiner_onnx, io_map={}
-        )
-
-        onnx.save(combined_model, all_in_one_filename)
-        logging.info(f"Saved to {all_in_one_filename}")
     elif params.jit is True:
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
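With the revert, each network is consumed directly from its own file, so no extraction step is needed. A minimal sketch of loading the three exported files with onnxruntime, using the file names produced by the export code above:

    import onnxruntime as ort

    # Each network now lives in its own file; load each one directly
    # instead of extracting it from a combined graph.
    encoder_session = ort.InferenceSession("encoder.onnx")
    decoder_session = ort.InferenceSession("decoder.onnx")
    joiner_session = ort.InferenceSession("joiner.onnx")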

View File

@@ -17,18 +17,18 @@
 # limitations under the License.
 """
-This script checks that exported onnx models produces the same output
-as the given torchscript model for the same input.
+This script checks that exported onnx models produce the same output
+with the given torchscript model for the same input.
 """
 import argparse
 import logging
 import onnxruntime as ort
 ort.set_default_logger_severity(3)
 import numpy as np
 import torch
@@ -85,7 +85,10 @@ def test_encoder(
     x_lens = torch.randint(low=10, high=T + 1, size=(N,))
     x_lens[0] = T
-    encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
+    encoder_inputs = {
+        "x": x.numpy(),
+        "x_lens": x_lens.numpy(),
+    }
     encoder_out, encoder_out_lens = encoder_session.run(
         ["encoder_out", "encoder_out_lens"],
         encoder_inputs,
@@ -141,7 +144,7 @@ def test_joiner(
         "encoder_out": encoder_out.numpy(),
         "decoder_out": decoder_out.numpy(),
     }
-    decoder_out = joiner_session.run(["logit"], joiner_inputs)[0]
+    joiner_out = joiner_session.run(["logit"], joiner_inputs)[0]
     joiner_out = torch.from_numpy(joiner_out)
     torch_joiner_out = model.joiner(
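The hunk ends inside the torchscript comparison. For context, each test in `onnx_check.py` reduces to running both backends on identical inputs and comparing the resulting tensors elementwise; a sketch of that final check, with the tolerance value an assumption since it is not shown in this hunk:

    import torch

    # joiner_out comes from onnxruntime, torch_joiner_out from the
    # torchscript model; the two should agree within a small tolerance.
    assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5)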