mirror of https://github.com/k2-fsa/icefall.git
Revert merging three onnx models into a single one.

It's quite time-consuming to extract a sub-graph from the combined model; extracting the encoder model alone, for instance, takes more than one hour.
This commit is contained in:
parent
3ebb52aa9b
commit
c70df281c6
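
The sub-graph extraction the message refers to looks roughly like the sketch below. This is a minimal sketch, not code from the repo: it assumes the combined model keeps the "encoder/" prefix that `onnx.compose.add_prefix` adds (see the removed code further down) and reuses the encoder I/O names from `onnx_check.py`; the output file name is illustrative.

# Sketch: pull the encoder sub-graph out of the combined ONNX model.
# The "encoder/" tensor-name prefix and the I/O names are assumptions
# based on this commit's diff; "encoder-extracted.onnx" is hypothetical.
import onnx
import onnx.utils

combined = onnx.load("all_in_one.onnx")
extractor = onnx.utils.Extractor(combined)

encoder = extractor.extract_model(
    input_names=["encoder/x", "encoder/x_lens"],
    output_names=["encoder/encoder_out", "encoder/encoder_out_lens"],
)
onnx.save(encoder, "encoder-extracted.onnx")

It is this `extract_model` call that reportedly takes more than an hour for the encoder, which is why the all-in-one export is being removed.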
@@ -34,7 +34,7 @@ It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
 load it by `torch.jit.load("cpu_jit.pt")`.

 Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
-are on CPU. You can use `to("cuda")` to move it to a CUDA device.
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.

 (2) Export to ONNX format

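The usage the docstring describes, as a two-line sketch (only the file name comes from the docstring):

import torch

model = torch.jit.load("cpu_jit.pt")  # parameters land on CPU, hence the name
model.to("cuda")                      # move them to a CUDA device
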
@@ -51,10 +51,7 @@ See `onnx_check.py` to see how to use it.
     - encoder.onnx
     - decoder.onnx
     - joiner.onnx
-    - all_in_one.onnx

-The file all_in_one.onnx combines `encoder.onnx`, `decoder.onnx`, and
-`joiner.onnx`. You can use `onnx.utils.Extractor` to extract them.

 (3) Export `model.state_dict()`

@@ -427,25 +424,6 @@ def main():
             joiner_filename,
             opset_version=opset_version,
         )
-
-        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
-        encoder_onnx = onnx.load(encoder_filename)
-        decoder_onnx = onnx.load(decoder_filename)
-        joiner_onnx = onnx.load(joiner_filename)
-
-        encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
-        decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
-        joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
-
-        combined_model = onnx.compose.merge_models(
-            encoder_onnx, decoder_onnx, io_map={}
-        )
-        combined_model = onnx.compose.merge_models(
-            combined_model, joiner_onnx, io_map={}
-        )
-        onnx.save(combined_model, all_in_one_filename)
-        logging.info(f"Saved to {all_in_one_filename}")
-
     elif params.jit is True:
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
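
The block removed above is the merge being reverted. Because `merge_models` was called with `io_map={}`, no output of one sub-model was wired to an input of another: the three prefixed graphs simply coexisted in one file, and every original input and output survived under its prefix. A small sketch for inspecting such a combined model (the file name is illustrative):

# List the surviving prefixed inputs/outputs of the combined model.
import onnx

combined = onnx.load("all_in_one.onnx")
print("inputs: ", [i.name for i in combined.graph.input])
print("outputs:", [o.name for o in combined.graph.output])
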
@@ -17,18 +17,18 @@
 # limitations under the License.

 """
-This script checks that exported onnx models produces the same output
-as the given torchscript model for the same input.
+This script checks that exported onnx models produce the same output
+with the given torchscript model for the same input.
 """

 import argparse
 import logging

 import onnxruntime as ort

 ort.set_default_logger_severity(3)

 import numpy as np

 import torch

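In miniature, the check this docstring describes feeds one input to both the onnxruntime session and the torchscript model and compares outputs. A hedged sketch: the I/O names come from this script, but the feature shape, the `model.encoder` entry point, and the tolerance are assumptions.

import numpy as np
import onnxruntime as ort
import torch

session = ort.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])
model = torch.jit.load("cpu_jit.pt")

N, T, C = 1, 100, 80  # batch, frames, feature dim (shape assumed)
x = torch.rand(N, T, C)
x_lens = torch.tensor([T], dtype=torch.int64)

# Run the exported ONNX encoder via onnxruntime.
onnx_out, _ = session.run(
    ["encoder_out", "encoder_out_lens"],
    {"x": x.numpy(), "x_lens": x_lens.numpy()},
)
# Run the torchscript model on the same input (entry point assumed).
torch_out, _ = model.encoder(x, x_lens)

assert np.allclose(onnx_out, torch_out.detach().numpy(), atol=1e-5)
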
@@ -85,7 +85,10 @@ def test_encoder(
     x_lens = torch.randint(low=10, high=T + 1, size=(N,))
     x_lens[0] = T

-    encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
+    encoder_inputs = {
+        "x": x.numpy(),
+        "x_lens": x_lens.numpy(),
+    }
     encoder_out, encoder_out_lens = encoder_session.run(
         ["encoder_out", "encoder_out_lens"],
         encoder_inputs,
@@ -141,7 +144,7 @@ def test_joiner(
         "encoder_out": encoder_out.numpy(),
         "decoder_out": decoder_out.numpy(),
     }
-    decoder_out = joiner_session.run(["logit"], joiner_inputs)[0]
+    joiner_out = joiner_session.run(["logit"], joiner_inputs)[0]
     joiner_out = torch.from_numpy(joiner_out)

     torch_joiner_out = model.joiner(
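
The last hunk is a genuine bug fix, not a stylistic rename: the old line bound the joiner session's output to `decoder_out`, so the following `torch.from_numpy(joiner_out)` referenced a name that was never assigned, and the real `decoder_out` (still needed for the torchscript comparison just below) was clobbered.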