mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
test RelPositionalEncoding
This commit is contained in:
parent
10360bed41
commit
317f47fb37
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from conformer import RelPositionalEncoding
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
|
||||
|
||||
class Foo(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
d_model = 512
|
||||
dropout = 0.1
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
(N, T, C)
|
||||
"""
|
||||
y, pos_emb = self.encoder_pos(x)
|
||||
return y, pos_emb
|
||||
|
||||
|
||||
def generate_pt():
|
||||
f = Foo()
|
||||
f.eval()
|
||||
f = convert_scaled_to_non_scaled(f)
|
||||
x = torch.rand(1, 6, 4)
|
||||
y, pos_emb = f(x)
|
||||
print("y.shape", y.shape)
|
||||
print("pos_emb.shape", pos_emb.shape)
|
||||
m = torch.jit.trace(f, x)
|
||||
m.save("foo/encoder_pos.pt")
|
||||
print(m.encoder_pos.pe[0].shape)
|
||||
print(type(m.encoder_pos.pe[0]))
|
||||
print(m.graph)
|
||||
|
||||
|
||||
def main():
|
||||
generate_pt()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220809)
|
||||
main()
|
@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
import ncnn
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
x = torch.rand(100, 512) # (T, C)
|
||||
m = torch.jit.load("foo/encoder_pos.pt")
|
||||
_, t = m(x.unsqueeze(0)) # bach size is 1
|
||||
t = t.squeeze(0) # (T, C)
|
||||
|
||||
param = "foo/encoder_pos.ncnn.param"
|
||||
model = "foo/encoder_pos.ncnn.bin"
|
||||
with ncnn.Net() as net:
|
||||
net.load_param(param)
|
||||
net.load_model(model)
|
||||
with net.create_extractor() as ex:
|
||||
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||
|
||||
ret, ncnn_out0 = ex.extract("out1")
|
||||
assert ret == 0, ret
|
||||
n = np.array(ncnn_out0)
|
||||
print(n.shape) # (6, 512), (T, C)
|
||||
n = torch.from_numpy(n)
|
||||
|
||||
print(t.reshape(-1)[:10])
|
||||
print(n.reshape(-1)[:10])
|
||||
assert torch.allclose(t, n), (t - n).abs().max()
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220808)
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user