fix CI tests

This commit is contained in:
Fangjun Kuang 2023-07-25 13:16:35 +08:00
parent 2ce61e6e7f
commit 721c8e95a5

View File

@ -265,7 +265,7 @@ def test_zipformer_encoder():
torch.onnx.export(
encoder,
(x),
(x, torch.ones(1, dtype=torch.float32)),
filename,
verbose=False,
opset_version=opset_version,
@ -289,6 +289,7 @@ def test_zipformer_encoder():
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
input_nodes[1].name: torch.ones(1, dtype=torch.float32).numpy(),
}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)