fix input/output names
This commit is contained in:
parent
40126355ec
commit
19d5435afc
@ -61,10 +61,10 @@ def test_encoder(
|
||||
encoder_session: ort.InferenceSession,
|
||||
):
|
||||
encoder_inputs = encoder_session.get_inputs()
|
||||
assert encoder_inputs[0].name == "encoder/x"
|
||||
assert encoder_inputs[1].name == "encoder/x_lens"
|
||||
assert encoder_inputs[0].shape == ["N", "T", 80]
|
||||
assert encoder_inputs[1].shape == ["N"]
|
||||
encoder_input_names = [i.name for i in encoder_inputs]
|
||||
encoder_output_names = [i.name for i in encoder_session.get_outputs()]
|
||||
|
||||
for N in [1, 5]:
|
||||
for T in [12, 25]:
|
||||
@ -74,11 +74,11 @@ def test_encoder(
|
||||
x_lens[0] = T
|
||||
|
||||
encoder_inputs = {
|
||||
"encoder/x": x.numpy(),
|
||||
"encoder/x_lens": x_lens.numpy(),
|
||||
encoder_input_names[0]: x.numpy(),
|
||||
encoder_input_names[1]: x_lens.numpy(),
|
||||
}
|
||||
encoder_out, encoder_out_lens = encoder_session.run(
|
||||
["encoder/encoder_out", "encoder/encoder_out_lens"],
|
||||
[encoder_output_names[1], encoder_output_names[0]],
|
||||
encoder_inputs,
|
||||
)
|
||||
|
||||
@ -95,14 +95,16 @@ def test_decoder(
|
||||
decoder_session: ort.InferenceSession,
|
||||
):
|
||||
decoder_inputs = decoder_session.get_inputs()
|
||||
assert decoder_inputs[0].name == "decoder/y"
|
||||
assert decoder_inputs[0].shape == ["N", 2]
|
||||
decoder_input_names = [i.name for i in decoder_inputs]
|
||||
decoder_output_names = [i.name for i in decoder_session.get_outputs()]
|
||||
|
||||
for N in [1, 5, 10]:
|
||||
y = torch.randint(low=1, high=500, size=(10, 2))
|
||||
|
||||
decoder_inputs = {"decoder/y": y.numpy()}
|
||||
decoder_inputs = {decoder_input_names[0]: y.numpy()}
|
||||
decoder_out = decoder_session.run(
|
||||
["decoder/decoder_out"],
|
||||
[decoder_output_names[0]],
|
||||
decoder_inputs,
|
||||
)[0]
|
||||
decoder_out = torch.from_numpy(decoder_out)
|
||||
@ -118,21 +120,22 @@ def test_joiner(
|
||||
joiner_session: ort.InferenceSession,
|
||||
):
|
||||
joiner_inputs = joiner_session.get_inputs()
|
||||
assert joiner_inputs[0].name == "joiner/encoder_out"
|
||||
assert joiner_inputs[0].shape == ["N", 512]
|
||||
|
||||
assert joiner_inputs[1].name == "joiner/decoder_out"
|
||||
assert joiner_inputs[1].shape == ["N", 512]
|
||||
joiner_input_names = [i.name for i in joiner_inputs]
|
||||
joiner_output_names = [i.name for i in joiner_session.get_outputs()]
|
||||
|
||||
for N in [1, 5, 10]:
|
||||
encoder_out = torch.rand(N, 512)
|
||||
decoder_out = torch.rand(N, 512)
|
||||
|
||||
joiner_inputs = {
|
||||
"joiner/encoder_out": encoder_out.numpy(),
|
||||
"joiner/decoder_out": decoder_out.numpy(),
|
||||
joiner_input_names[0]: encoder_out.numpy(),
|
||||
joiner_input_names[1]: decoder_out.numpy(),
|
||||
}
|
||||
joiner_out = joiner_session.run(["joiner/logit"], joiner_inputs)[0]
|
||||
joiner_out = joiner_session.run(
|
||||
[joiner_output_names[0]], joiner_inputs
|
||||
)[0]
|
||||
joiner_out = torch.from_numpy(joiner_out)
|
||||
|
||||
torch_joiner_out = model.joiner(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user