mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
fix CI tests
This commit is contained in:
parent
2ce61e6e7f
commit
721c8e95a5
@ -265,7 +265,7 @@ def test_zipformer_encoder():
|
|||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
encoder,
|
encoder,
|
||||||
(x),
|
(x, torch.ones(1, dtype=torch.float32)),
|
||||||
filename,
|
filename,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
@ -289,6 +289,7 @@ def test_zipformer_encoder():
|
|||||||
input_nodes = session.get_inputs()
|
input_nodes = session.get_inputs()
|
||||||
inputs = {
|
inputs = {
|
||||||
input_nodes[0].name: x.numpy(),
|
input_nodes[0].name: x.numpy(),
|
||||||
|
input_nodes[1].name: torch.ones(1, dtype=torch.float32).numpy(),
|
||||||
}
|
}
|
||||||
onnx_y = session.run(["y"], inputs)[0]
|
onnx_y = session.run(["y"], inputs)[0]
|
||||||
onnx_y = torch.from_numpy(onnx_y)
|
onnx_y = torch.from_numpy(onnx_y)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user