Support converting RelPositionalEncoding.

Fangjun Kuang 2022-08-08 15:27:04 +08:00
parent d69d83a83e
commit 365c6aa045
3 changed files with 67 additions and 21 deletions


@@ -798,7 +798,7 @@ class RelPositionalEncoding(torch.nn.Module):
         self.d_model = d_model
         self.dropout = torch.nn.Dropout(p=dropout_rate)
-        self.pe = None
+        self.register_buffer("pe", None)
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))

     def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
@@ -835,7 +835,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
+        self.register_buffer("pe", pe.to(device=x.device, dtype=x.dtype))

     def forward(
         self,
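The two hunks above replace plain attribute assignment with register_buffer. A plain tensor attribute such as self.pe never appears in the module's state_dict() and is not tracked as module state, whereas a registered buffer is saved and also follows .to() / .cuda() moves; that is presumably what the tracing and pnnx/ncnn conversion path needs here. A minimal sketch of the difference, using an illustrative M class that is not from this repo:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = None                 # plain attribute: never saved
        self.register_buffer("pe", None)  # reserves the buffer slot "pe"

    def forward(self, x):
        # Re-registering under the same name simply replaces the buffer,
        # which is what extend_pe() does in the hunk above.
        self.register_buffer("pe", torch.arange(x.numel(), dtype=x.dtype))
        self.plain = self.pe.clone()
        return x + self.pe

m = M()
m(torch.zeros(3))
print(list(m.state_dict().keys()))  # ['pe'] -- 'plain' is not listed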


@@ -11,26 +11,38 @@ LOG_EPS = math.log(1e-10)

 @torch.no_grad()
 def main():
-    x = torch.rand(1, 200, 80)
-    f = torch.jit.load("foo/encoder_embed.pt")
+    x = torch.rand(10, 3)
+    f = torch.jit.load("foo/encoder_pos.pt")

-    param = "foo/encoder_embed.ncnn.param"
-    model = "foo/encoder_embed.ncnn.bin"
+    param = "foo/encoder_pos.ncnn.param"
+    model = "foo/encoder_pos.ncnn.bin"

     with ncnn.Net() as net:
         net.load_param(param)
         net.load_model(model)

         with net.create_extractor() as ex:
             ex.input("in0", ncnn.Mat(x.numpy()).clone())
-            ret, out0 = ex.extract("out0")
-            assert ret == 0
-            out0 = np.array(out0)
-            print("ncnn", out0.shape)
-            t = f(x)
-            out0 = torch.from_numpy(out0)
-            t = t.squeeze(0)
-            print("torch", t.shape)
-            torch.allclose(out0, t), (t - out0).abs().max()
+            ret, ncnn_out0 = ex.extract("out0")
+            assert ret == 0, ret
+            ncnn_out0 = np.array(ncnn_out0)
+
+            ret, ncnn_out1 = ex.extract("out1")
+            assert ret == 0, ret
+            ncnn_out1 = np.array(ncnn_out1)
+
+            torch_out0, torch_out1 = f(x.unsqueeze(0))
+            torch_out0 = torch_out0.squeeze(0)
+            torch_out1 = torch_out1.squeeze(0)
+            ncnn_out0 = torch.from_numpy(ncnn_out0)
+            ncnn_out1 = torch.from_numpy(ncnn_out1)
+            assert torch.allclose(torch_out0, ncnn_out0), (
+                torch_out0 - ncnn_out0
+            ).abs().max()
+            assert torch.allclose(torch_out1, ncnn_out1), (
+                torch_out1 - ncnn_out1
+            ).abs().max()


 if __name__ == "__main__":
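For context: foo/encoder_pos.ncnn.param and foo/encoder_pos.ncnn.bin are presumably generated from the traced foo/encoder_pos.pt with pnnx before this script runs; the conversion step itself is not part of the commit. RelPositionalEncoding.forward returns a pair (the dropout-wrapped input and the positional embedding), which is why the script now extracts two outputs, out0 and out1, and the squeeze calls reflect that the ncnn outputs come back without the leading batch dimension of 1. The comparison boils down to a pattern like the following hypothetical helper (assert_close is not part of the repo):

import numpy as np
import torch

def assert_close(torch_out, ncnn_out, atol=1e-5):
    # Compare a TorchScript output against an ncnn.Mat.
    n = torch.from_numpy(np.array(ncnn_out))
    t = torch_out.squeeze(0)  # drop the batch dim to match ncnn's layout
    assert t.shape == n.shape, (t.shape, n.shape)
    assert torch.allclose(t, n, atol=atol), (t - n).abs().max()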


@@ -2,10 +2,40 @@

 import torch
 import torch.nn as nn
+from conformer import RelPositionalEncoding
 from scaling_converter import convert_scaled_to_non_scaled
 from train import get_params, get_transducer_model


+class Foo(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(3, 100)
+        self.encoder_pos = RelPositionalEncoding(100, 0.1)
+        self.linear2 = nn.Linear(100, 2)
+
+    def forward(self, x):
+        y = self.linear(x)
+        z, embed = self.encoder_pos(y)
+        return z, embed
+
+
+def test():
+    f = Foo()
+    f.eval()
+    # f.encoder_pos.for_ncnn = True
+    x = torch.rand(1, 10, 3)
+    y, _ = f(x)
+    print(y.shape)
+    # print(embed.shape)
+    m = torch.jit.trace(f, x)
+    m.save("foo/encoder_pos.pt")
+    print(m.graph)
+    # print(m.encoder_pos.graph)
+
+
 def get_model():
     params = get_params()
     params.vocab_size = 500
@@ -24,20 +54,24 @@ def get_model():

 def test_encoder_embedding():
     model = get_model()
     model.eval()
     model = convert_scaled_to_non_scaled(model)
-    f = model.encoder.encoder_embed
-    f.for_ncnn = True
+    f = model.encoder
+    f.encoder_pos.for_ncnn = True
+    f.for_ncnn = True
     print(f)

     x = torch.rand(1, 100, 80)  # NTC
-    m = torch.jit.trace(f, x)
-    m.save("foo/encoder_embed.pt")
+    x_lens = torch.tensor([100])
+    m = torch.jit.trace(f, (x, x_lens))
+    m.save("foo/encoder_pos.pt")
     print(m.graph)


 @torch.no_grad()
 def main():
-    test_encoder_embedding()
+    # test_encoder_embedding()
+    test()


 if __name__ == "__main__":
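The for_ncnn flags set in test_encoder_embedding() suggest that the encoder and RelPositionalEncoding have an export mode that switches to converter-friendly code paths during tracing; those branches live in conformer.py and are not shown in this diff. A hypothetical illustration of the pattern (class name and branch bodies are assumptions):

import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.for_ncnn = False  # set to True right before torch.jit.trace

    def forward(self, x):
        if self.for_ncnn:
            # torch.jit.trace records only the branch that executes, so
            # setting the flag first bakes this path into the traced graph.
            return x.reshape(-1).sum()
        return x.flatten().sum()  # equivalent path for normal PyTorch use

After tracing, the saved foo/encoder_pos.pt would be converted with pnnx (something like pnnx foo/encoder_pos.pt inputshape=[1,10,3]; the exact invocation is an assumption, as it is not part of this commit), which emits the foo/encoder_pos.ncnn.param / foo/encoder_pos.ncnn.bin pair loaded by the test script in this commit.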