fix input/output names

This commit is contained in:
Yunus Emre Özköse 2022-08-04 09:11:21 +03:00
parent 40126355ec
commit 19d5435afc

View File

@ -61,10 +61,10 @@ def test_encoder(
encoder_session: ort.InferenceSession,
):
encoder_inputs = encoder_session.get_inputs()
assert encoder_inputs[0].name == "encoder/x"
assert encoder_inputs[1].name == "encoder/x_lens"
assert encoder_inputs[0].shape == ["N", "T", 80]
assert encoder_inputs[1].shape == ["N"]
encoder_input_names = [i.name for i in encoder_inputs]
encoder_output_names = [i.name for i in encoder_session.get_outputs()]
for N in [1, 5]:
for T in [12, 25]:
@ -74,11 +74,11 @@ def test_encoder(
x_lens[0] = T
encoder_inputs = {
"encoder/x": x.numpy(),
"encoder/x_lens": x_lens.numpy(),
encoder_input_names[0]: x.numpy(),
encoder_input_names[1]: x_lens.numpy(),
}
encoder_out, encoder_out_lens = encoder_session.run(
["encoder/encoder_out", "encoder/encoder_out_lens"],
[encoder_output_names[1], encoder_output_names[0]],
encoder_inputs,
)
@ -95,14 +95,16 @@ def test_decoder(
decoder_session: ort.InferenceSession,
):
decoder_inputs = decoder_session.get_inputs()
assert decoder_inputs[0].name == "decoder/y"
assert decoder_inputs[0].shape == ["N", 2]
decoder_input_names = [i.name for i in decoder_inputs]
decoder_output_names = [i.name for i in decoder_session.get_outputs()]
for N in [1, 5, 10]:
y = torch.randint(low=1, high=500, size=(10, 2))
decoder_inputs = {"decoder/y": y.numpy()}
decoder_inputs = {decoder_input_names[0]: y.numpy()}
decoder_out = decoder_session.run(
["decoder/decoder_out"],
[decoder_output_names[0]],
decoder_inputs,
)[0]
decoder_out = torch.from_numpy(decoder_out)
@ -118,21 +120,22 @@ def test_joiner(
joiner_session: ort.InferenceSession,
):
joiner_inputs = joiner_session.get_inputs()
assert joiner_inputs[0].name == "joiner/encoder_out"
assert joiner_inputs[0].shape == ["N", 512]
assert joiner_inputs[1].name == "joiner/decoder_out"
assert joiner_inputs[1].shape == ["N", 512]
joiner_input_names = [i.name for i in joiner_inputs]
joiner_output_names = [i.name for i in joiner_session.get_outputs()]
for N in [1, 5, 10]:
encoder_out = torch.rand(N, 512)
decoder_out = torch.rand(N, 512)
joiner_inputs = {
"joiner/encoder_out": encoder_out.numpy(),
"joiner/decoder_out": decoder_out.numpy(),
joiner_input_names[0]: encoder_out.numpy(),
joiner_input_names[1]: decoder_out.numpy(),
}
joiner_out = joiner_session.run(["joiner/logit"], joiner_inputs)[0]
joiner_out = joiner_session.run(
[joiner_output_names[0]], joiner_inputs
)[0]
joiner_out = torch.from_numpy(joiner_out)
torch_joiner_out = model.joiner(