diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index bd1dc0e20..0fd6d41b3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -34,7 +34,7 @@ It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later load it by `torch.jit.load("cpu_jit.pt")`. Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move it to a CUDA device. +are on CPU. You can use `to("cuda")` to move them to a CUDA device. (2) Export to ONNX format @@ -51,10 +51,7 @@ See `onnx_check.py` to see how to use it. - encoder.onnx - decoder.onnx - joiner.onnx - - all_in_one.onnx -The file all_in_one.onnx combines `encoder.onnx`, `decoder.onnx`, and -`joiner.onnx`. You can use `onnx.utils.Extractor` to extract them. (3) Export `model.state_dict()` @@ -427,25 +424,6 @@ def main(): joiner_filename, opset_version=opset_version, ) - - all_in_one_filename = params.exp_dir / "all_in_one.onnx" - encoder_onnx = onnx.load(encoder_filename) - decoder_onnx = onnx.load(decoder_filename) - joiner_onnx = onnx.load(joiner_filename) - - encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/") - decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/") - joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/") - - combined_model = onnx.compose.merge_models( - encoder_onnx, decoder_onnx, io_map={} - ) - combined_model = onnx.compose.merge_models( - combined_model, joiner_onnx, io_map={} - ) - onnx.save(combined_model, all_in_one_filename) - logging.info(f"Saved to {all_in_one_filename}") - elif params.jit is True: # We won't use the forward() method of the model in C++, so just ignore # it here. 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index df0d0f09e..e10caf62e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -17,18 +17,18 @@ # limitations under the License. """ -This script checks that exported onnx models produces the same output -as the given torchscript model for the same input. +This script checks that exported onnx models produce the same output +as the given torchscript model for the same input. """ import argparse import logging + import onnxruntime as ort ort.set_default_logger_severity(3) import numpy as np - import torch @@ -85,7 +85,10 @@ def test_encoder( x_lens = torch.randint(low=10, high=T + 1, size=(N,)) x_lens[0] = T - encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()} + encoder_inputs = { + "x": x.numpy(), + "x_lens": x_lens.numpy(), + } encoder_out, encoder_out_lens = encoder_session.run( ["encoder_out", "encoder_out_lens"], encoder_inputs, @@ -141,7 +144,7 @@ def test_joiner( "encoder_out": encoder_out.numpy(), "decoder_out": decoder_out.numpy(), } - decoder_out = joiner_session.run(["logit"], joiner_inputs)[0] + joiner_out = joiner_session.run(["logit"], joiner_inputs)[0] joiner_out = torch.from_numpy(joiner_out) torch_joiner_out = model.joiner(