From 19d5435afcf02446fb972aeead53af199d7643b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yunus=20Emre=20=C3=96zk=C3=B6se?= Date: Thu, 4 Aug 2022 09:11:21 +0300 Subject: [PATCH] fix input/output names --- .../onnx_check_all_in_one.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py index d9a23e1b6..b4cf8c94a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py @@ -61,10 +61,10 @@ def test_encoder( encoder_session: ort.InferenceSession, ): encoder_inputs = encoder_session.get_inputs() - assert encoder_inputs[0].name == "encoder/x" - assert encoder_inputs[1].name == "encoder/x_lens" assert encoder_inputs[0].shape == ["N", "T", 80] assert encoder_inputs[1].shape == ["N"] + encoder_input_names = [i.name for i in encoder_inputs] + encoder_output_names = [i.name for i in encoder_session.get_outputs()] for N in [1, 5]: for T in [12, 25]: @@ -74,11 +74,11 @@ def test_encoder( x_lens[0] = T encoder_inputs = { - "encoder/x": x.numpy(), - "encoder/x_lens": x_lens.numpy(), + encoder_input_names[0]: x.numpy(), + encoder_input_names[1]: x_lens.numpy(), } encoder_out, encoder_out_lens = encoder_session.run( - ["encoder/encoder_out", "encoder/encoder_out_lens"], + [encoder_output_names[1], encoder_output_names[0]], encoder_inputs, ) @@ -95,14 +95,16 @@ def test_decoder( decoder_session: ort.InferenceSession, ): decoder_inputs = decoder_session.get_inputs() - assert decoder_inputs[0].name == "decoder/y" assert decoder_inputs[0].shape == ["N", 2] + decoder_input_names = [i.name for i in decoder_inputs] + decoder_output_names = [i.name for i in decoder_session.get_outputs()] + for N in [1, 5, 10]: y = torch.randint(low=1, high=500, size=(10, 2)) - decoder_inputs = {"decoder/y": y.numpy()} + decoder_inputs = {decoder_input_names[0]: y.numpy()} decoder_out = decoder_session.run( - ["decoder/decoder_out"], + [decoder_output_names[0]], decoder_inputs, )[0] decoder_out = torch.from_numpy(decoder_out) @@ -118,21 +120,22 @@ def test_joiner( joiner_session: ort.InferenceSession, ): joiner_inputs = joiner_session.get_inputs() - assert joiner_inputs[0].name == "joiner/encoder_out" assert joiner_inputs[0].shape == ["N", 512] - - assert joiner_inputs[1].name == "joiner/decoder_out" assert joiner_inputs[1].shape == ["N", 512] + joiner_input_names = [i.name for i in joiner_inputs] + joiner_output_names = [i.name for i in joiner_session.get_outputs()] for N in [1, 5, 10]: encoder_out = torch.rand(N, 512) decoder_out = torch.rand(N, 512) joiner_inputs = { - "joiner/encoder_out": encoder_out.numpy(), - "joiner/decoder_out": decoder_out.numpy(), + joiner_input_names[0]: encoder_out.numpy(), + joiner_input_names[1]: decoder_out.numpy(), } - joiner_out = joiner_session.run(["joiner/logit"], joiner_inputs)[0] + joiner_out = joiner_session.run( + [joiner_output_names[0]], joiner_inputs + )[0] joiner_out = torch.from_numpy(joiner_out) torch_joiner_out = model.joiner(