Fix tests

This commit is contained in:
Fangjun Kuang 2023-07-25 12:44:49 +08:00
parent 25641b2ead
commit 2ce61e6e7f
2 changed files with 6 additions and 2 deletions

View File

@ -849,6 +849,8 @@ class RelPositionalEncoding(torch.nn.Module):
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
if isinstance(left_context, torch.Tensor):
left_context = left_context.item()
self.extend_pe(x, left_context)
x_size_1 = x.size(1) + left_context
pos_emb = self.pe[

View File

@ -113,7 +113,7 @@ def test_rel_pos():
torch.onnx.export(
encoder_pos,
x,
(x, torch.zeros(1, dtype=torch.int64)),
filename,
verbose=False,
opset_version=opset_version,
@ -139,7 +139,9 @@ def test_rel_pos():
assert input_nodes[0].name == "x"
assert input_nodes[0].shape == ["N", "T", num_features]
inputs = {input_nodes[0].name: x.numpy()}
inputs = {
input_nodes[0].name: x.numpy(),
}
onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs)
onnx_y = torch.from_numpy(onnx_y)
onnx_pos_emb = torch.from_numpy(onnx_pos_emb)