test RelPositionalEncoding

Fangjun Kuang 2022-08-09 23:32:28 +08:00
parent 10360bed41
commit 317f47fb37
2 changed files with 88 additions and 0 deletions


@@ -0,0 +1,49 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn

from conformer import RelPositionalEncoding
from scaling_converter import convert_scaled_to_non_scaled


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        d_model = 512
        dropout = 0.1
        self.encoder_pos = RelPositionalEncoding(d_model, dropout)

    def forward(self, x: torch.Tensor):
        """
        Args:
          x:
            (N, T, C)
        """
        y, pos_emb = self.encoder_pos(x)
        return y, pos_emb


def generate_pt():
    f = Foo()
    f.eval()
    f = convert_scaled_to_non_scaled(f)
    x = torch.rand(1, 6, 4)
    y, pos_emb = f(x)
    print("y.shape", y.shape)
    print("pos_emb.shape", pos_emb.shape)
    m = torch.jit.trace(f, x)
    m.save("foo/encoder_pos.pt")
    print(m.encoder_pos.pe[0].shape)
    print(type(m.encoder_pos.pe[0]))
    print(m.graph)


def main():
    generate_pt()


if __name__ == "__main__":
    torch.manual_seed(20220809)
    main()
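
Note: the second script below reads foo/encoder_pos.ncnn.param and foo/encoder_pos.ncnn.bin, which neither file in this commit creates. Presumably they are generated by running the pnnx converter on the traced foo/encoder_pos.pt produced above (for example, pnnx foo/encoder_pos.pt inputshape=[1,6,4]); treat the exact command as an assumption, since that conversion step is not part of this diff.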


@@ -0,0 +1,39 @@
#!/usr/bin/env python3

import ncnn
import numpy as np
import torch


@torch.no_grad()
def main():
    x = torch.rand(100, 512)  # (T, C)
    m = torch.jit.load("foo/encoder_pos.pt")
    _, t = m(x.unsqueeze(0))  # batch size is 1
    t = t.squeeze(0)  # (T, C)

    param = "foo/encoder_pos.ncnn.param"
    model = "foo/encoder_pos.ncnn.bin"
    with ncnn.Net() as net:
        net.load_param(param)
        net.load_model(model)
        with net.create_extractor() as ex:
            ex.input("in0", ncnn.Mat(x.numpy()).clone())
            ret, ncnn_out0 = ex.extract("out1")
            assert ret == 0, ret
            n = np.array(ncnn_out0)
            print(n.shape)  # (6, 512), (T, C)
            n = torch.from_numpy(n)
            print(t.reshape(-1)[:10])
            print(n.reshape(-1)[:10])
            assert torch.allclose(t, n), (t - n).abs().max()


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    torch.manual_seed(20220808)
    main()
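
The script above only compares the second output (out1, the positional embedding) against ncnn. A minimal sketch of the analogous check for the first output is given below; the blob name out0 is an assumption that mirrors the in0/out1 naming already used above and is not confirmed by this commit.

#!/usr/bin/env python3
# Sketch only: assumes "out0" is the blob name of the first output (y)
# in foo/encoder_pos.ncnn.param; mirrors the out1 check in the script above.
import ncnn
import numpy as np
import torch


@torch.no_grad()
def check_first_output():
    x = torch.rand(100, 512)  # (T, C)
    m = torch.jit.load("foo/encoder_pos.pt")
    y, _ = m(x.unsqueeze(0))  # batch size is 1
    y = y.squeeze(0)  # (T, C)

    with ncnn.Net() as net:
        net.load_param("foo/encoder_pos.ncnn.param")
        net.load_model("foo/encoder_pos.ncnn.bin")
        with net.create_extractor() as ex:
            ex.input("in0", ncnn.Mat(x.numpy()).clone())
            ret, ncnn_out0 = ex.extract("out0")  # assumed blob name
            assert ret == 0, ret
            n = torch.from_numpy(np.array(ncnn_out0))
            assert torch.allclose(y, n), (y - n).abs().max()


if __name__ == "__main__":
    torch.manual_seed(20220808)
    check_first_output()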