Revert merging three onnx models into a single one.

It is quite time-consuming to extract a sub-graph from the combined model. For instance, it takes more than one hour to extract the encoder model.
parent 3ebb52aa9b
commit c70df281c6
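What the revert removes is the onnx.compose merging step shown in the diff below; getting a sub-model back out of `all_in_one.onnx` went through `onnx.utils.Extractor`, and that extraction is the slow part. A minimal sketch of that extraction, assuming the `encoder/` name prefix added by the removed code (the exact tensor names are illustrative):

import onnx

# Load the combined model produced by the (now reverted) merging code.
model = onnx.load("all_in_one.onnx")

# Tensors of the encoder sub-graph carry the "encoder/" prefix added via
# onnx.compose.add_prefix. The tensor names below are assumptions.
extractor = onnx.utils.Extractor(model)
encoder = extractor.extract_model(
    input_names=["encoder/x", "encoder/x_lens"],
    output_names=["encoder/encoder_out", "encoder/encoder_out_lens"],
)
onnx.save(encoder, "encoder.onnx")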
@@ -34,7 +34,7 @@ It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
 load it by `torch.jit.load("cpu_jit.pt")`.
 
 Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
-are on CPU. You can use `to("cuda")` to move it to a CUDA device.
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
 
 (2) Export to ONNX format
 
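For example, the exported TorchScript model can be used like this (a minimal sketch of the usage described in the docstring above):

import torch

model = torch.jit.load("cpu_jit.pt")  # parameters are loaded onto CPU
model = model.to("cuda")  # move them to a CUDA device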
@@ -51,10 +51,7 @@ See `onnx_check.py` to see how to use it.
 - encoder.onnx
 - decoder.onnx
 - joiner.onnx
-- all_in_one.onnx
 
-The file all_in_one.onnx combines `encoder.onnx`, `decoder.onnx`, and
-`joiner.onnx`. You can use `onnx.utils.Extractor` to extract them.
 
 (3) Export `model.state_dict()`
 
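After the revert, the three files are consumed individually, e.g. by creating one onnxruntime session per model as `onnx_check.py` does. A minimal sketch, assuming CPU execution:

import onnxruntime as ort

encoder_session = ort.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])
decoder_session = ort.InferenceSession("decoder.onnx", providers=["CPUExecutionProvider"])
joiner_session = ort.InferenceSession("joiner.onnx", providers=["CPUExecutionProvider"])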
@@ -427,25 +424,6 @@ def main():
             joiner_filename,
             opset_version=opset_version,
         )
-
-        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
-        encoder_onnx = onnx.load(encoder_filename)
-        decoder_onnx = onnx.load(decoder_filename)
-        joiner_onnx = onnx.load(joiner_filename)
-
-        encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
-        decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
-        joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
-
-        combined_model = onnx.compose.merge_models(
-            encoder_onnx, decoder_onnx, io_map={}
-        )
-        combined_model = onnx.compose.merge_models(
-            combined_model, joiner_onnx, io_map={}
-        )
-        onnx.save(combined_model, all_in_one_filename)
-        logging.info(f"Saved to {all_in_one_filename}")
-
     elif params.jit is True:
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
@@ -17,18 +17,18 @@
 # limitations under the License.
 
 """
-This script checks that exported onnx models produces the same output
-as the given torchscript model for the same input.
+This script checks that exported onnx models produce the same output
+with the given torchscript model for the same input.
 """
 
 import argparse
 import logging
 
 import onnxruntime as ort
 
 ort.set_default_logger_severity(3)
 
 import numpy as np
 
 import torch
 
@@ -85,7 +85,10 @@ def test_encoder(
     x_lens = torch.randint(low=10, high=T + 1, size=(N,))
     x_lens[0] = T
 
-    encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
+    encoder_inputs = {
+        "x": x.numpy(),
+        "x_lens": x_lens.numpy(),
+    }
     encoder_out, encoder_out_lens = encoder_session.run(
         ["encoder_out", "encoder_out_lens"],
         encoder_inputs,
@@ -141,7 +144,7 @@ def test_joiner(
         "encoder_out": encoder_out.numpy(),
         "decoder_out": decoder_out.numpy(),
     }
-    decoder_out = joiner_session.run(["logit"], joiner_inputs)[0]
+    joiner_out = joiner_session.run(["logit"], joiner_inputs)[0]
     joiner_out = torch.from_numpy(joiner_out)
 
     torch_joiner_out = model.joiner(
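The renaming above matters because the old line clobbered `decoder_out` while leaving `joiner_out` unassigned before its use on the next line; the check then compares the two outputs along these lines (a sketch; the tolerance is an assumed value, not taken from the diff):

# Compare the onnxruntime joiner output against the torchscript joiner.
# With the old code, decoder_out was overwritten and joiner_out undefined.
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-05)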