mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Support converting RelPositionalEncoding.
This commit is contained in:
parent
d69d83a83e
commit
365c6aa045
@ -798,7 +798,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
|
||||
self.d_model = d_model
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.register_buffer("pe", None)
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
|
||||
@ -835,7 +835,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
self.register_buffer("pe", pe.to(device=x.device, dtype=x.dtype))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -11,26 +11,38 @@ LOG_EPS = math.log(1e-10)
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
x = torch.rand(1, 200, 80)
|
||||
f = torch.jit.load("foo/encoder_embed.pt")
|
||||
x = torch.rand(10, 3)
|
||||
f = torch.jit.load("foo/encoder_pos.pt")
|
||||
|
||||
param = "foo/encoder_embed.ncnn.param"
|
||||
model = "foo/encoder_embed.ncnn.bin"
|
||||
param = "foo/encoder_pos.ncnn.param"
|
||||
model = "foo/encoder_pos.ncnn.bin"
|
||||
|
||||
with ncnn.Net() as net:
|
||||
net.load_param(param)
|
||||
net.load_model(model)
|
||||
with net.create_extractor() as ex:
|
||||
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||
ret, out0 = ex.extract("out0")
|
||||
assert ret == 0
|
||||
out0 = np.array(out0)
|
||||
print("ncnn", out0.shape)
|
||||
t = f(x)
|
||||
out0 = torch.from_numpy(out0)
|
||||
t = t.squeeze(0)
|
||||
print("torch", t.shape)
|
||||
torch.allclose(out0, t), (t - out0).abs().max()
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
ncnn_out0 = np.array(ncnn_out0)
|
||||
|
||||
ret, ncnn_out1 = ex.extract("out1")
|
||||
assert ret == 0, ret
|
||||
ncnn_out1 = np.array(ncnn_out1)
|
||||
|
||||
torch_out0, torch_out1 = f(x.unsqueeze(0))
|
||||
torch_out0 = torch_out0.squeeze(0)
|
||||
torch_out1 = torch_out1.squeeze(1)
|
||||
|
||||
ncnn_out0 = torch.from_numpy(ncnn_out0)
|
||||
ncnn_out1 = torch.from_numpy(ncnn_out1)
|
||||
|
||||
torch.allclose(torch_out0, ncnn_out0), (
|
||||
torch_out0 - ncnn_out0
|
||||
).abs().max()
|
||||
torch.allclose(torch_out1, ncnn_out1), (
|
||||
torch_out1 - ncnn_out1
|
||||
).abs().max()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -2,10 +2,40 @@
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from conformer import RelPositionalEncoding
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
class Foo(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(3, 100)
|
||||
self.encoder_pos = RelPositionalEncoding(100, 0.1)
|
||||
self.linear2 = nn.Linear(100, 2)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.linear(x)
|
||||
z, embed = self.encoder_pos(y)
|
||||
return z, embed
|
||||
|
||||
|
||||
def test():
|
||||
f = Foo()
|
||||
f.eval()
|
||||
# f.encoder_pos.for_ncnn = True
|
||||
x = torch.rand(1, 10, 3)
|
||||
y, _ = f(x)
|
||||
print(y.shape)
|
||||
# print(embed.shape)
|
||||
|
||||
m = torch.jit.trace(f, x)
|
||||
m.save("foo/encoder_pos.pt")
|
||||
print(m.graph)
|
||||
# print(m.encoder_pos.graph)
|
||||
|
||||
|
||||
def get_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
@ -24,20 +54,24 @@ def get_model():
|
||||
|
||||
def test_encoder_embedding():
|
||||
model = get_model()
|
||||
model.eval()
|
||||
model = convert_scaled_to_non_scaled(model)
|
||||
|
||||
f = model.encoder.encoder_embed
|
||||
f = model.encoder
|
||||
f.for_ncnn = True
|
||||
f.encoder_pos.for_ncnn = True
|
||||
|
||||
f.for_ncnn = True
|
||||
print(f)
|
||||
x = torch.rand(1, 100, 80) # NTC
|
||||
m = torch.jit.trace(f, x)
|
||||
m.save("foo/encoder_embed.pt")
|
||||
x_lens = torch.tensor([100])
|
||||
m = torch.jit.trace(f, (x, x_lens))
|
||||
m.save("foo/encoder_pos.pt")
|
||||
print(m.graph)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
test_encoder_embedding()
|
||||
# test_encoder_embedding()
|
||||
test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user