mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix tests
This commit is contained in:
parent
25641b2ead
commit
2ce61e6e7f
@ -849,6 +849,8 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
if isinstance(left_context, torch.Tensor):
|
||||
left_context = left_context.item()
|
||||
self.extend_pe(x, left_context)
|
||||
x_size_1 = x.size(1) + left_context
|
||||
pos_emb = self.pe[
|
||||
|
@ -113,7 +113,7 @@ def test_rel_pos():
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_pos,
|
||||
x,
|
||||
(x, torch.zeros(1, dtype=torch.int64)),
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
@ -139,7 +139,9 @@ def test_rel_pos():
|
||||
assert input_nodes[0].name == "x"
|
||||
assert input_nodes[0].shape == ["N", "T", num_features]
|
||||
|
||||
inputs = {input_nodes[0].name: x.numpy()}
|
||||
inputs = {
|
||||
input_nodes[0].name: x.numpy(),
|
||||
}
|
||||
onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs)
|
||||
onnx_y = torch.from_numpy(onnx_y)
|
||||
onnx_pos_emb = torch.from_numpy(onnx_pos_emb)
|
||||
|
Loading…
x
Reference in New Issue
Block a user