mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix tests
This commit is contained in:
parent
25641b2ead
commit
2ce61e6e7f
@ -849,6 +849,8 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(left_context, torch.Tensor):
|
||||||
|
left_context = left_context.item()
|
||||||
self.extend_pe(x, left_context)
|
self.extend_pe(x, left_context)
|
||||||
x_size_1 = x.size(1) + left_context
|
x_size_1 = x.size(1) + left_context
|
||||||
pos_emb = self.pe[
|
pos_emb = self.pe[
|
||||||
|
@ -113,7 +113,7 @@ def test_rel_pos():
|
|||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
encoder_pos,
|
encoder_pos,
|
||||||
x,
|
(x, torch.zeros(1, dtype=torch.int64)),
|
||||||
filename,
|
filename,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
@ -139,7 +139,9 @@ def test_rel_pos():
|
|||||||
assert input_nodes[0].name == "x"
|
assert input_nodes[0].name == "x"
|
||||||
assert input_nodes[0].shape == ["N", "T", num_features]
|
assert input_nodes[0].shape == ["N", "T", num_features]
|
||||||
|
|
||||||
inputs = {input_nodes[0].name: x.numpy()}
|
inputs = {
|
||||||
|
input_nodes[0].name: x.numpy(),
|
||||||
|
}
|
||||||
onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs)
|
onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs)
|
||||||
onnx_y = torch.from_numpy(onnx_y)
|
onnx_y = torch.from_numpy(onnx_y)
|
||||||
onnx_pos_emb = torch.from_numpy(onnx_pos_emb)
|
onnx_pos_emb = torch.from_numpy(onnx_pos_emb)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user