fix CI tests

This commit is contained in:
Fangjun Kuang 2023-07-25 13:16:35 +08:00
parent 2ce61e6e7f
commit 721c8e95a5

View File

@ -265,7 +265,7 @@ def test_zipformer_encoder():
torch.onnx.export( torch.onnx.export(
encoder, encoder,
(x), (x, torch.ones(1, dtype=torch.float32)),
filename, filename,
verbose=False, verbose=False,
opset_version=opset_version, opset_version=opset_version,
@ -289,6 +289,7 @@ def test_zipformer_encoder():
input_nodes = session.get_inputs() input_nodes = session.get_inputs()
inputs = { inputs = {
input_nodes[0].name: x.numpy(), input_nodes[0].name: x.numpy(),
input_nodes[1].name: torch.ones(1, dtype=torch.float32).numpy(),
} }
onnx_y = session.run(["y"], inputs)[0] onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y) onnx_y = torch.from_numpy(onnx_y)