Test encoder_embed

Fangjun Kuang 2022-08-09 20:03:38 +08:00
parent 365c6aa045
commit 10360bed41
4 changed files with 96 additions and 5 deletions

View File

@@ -1611,9 +1611,8 @@ class Conv2dSubsampling(nn.Module):
         x = self.conv(x)
         # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         if torch.jit.is_tracing() and self.for_ncnn:
-            x = self.out(
-                x.transpose(1, 2).contiguous().view(1, -1, self.conv_out_dim)
-            )
+            x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
+            x = self.out(x)
         else:
             b, c, t, f = x.size()
             x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
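
The ncnn tracing branch now uses permute(0, 2, 1, 3) followed by reshape instead of transpose(1, 2).contiguous().view(...), which presumably exports more cleanly; the result is unchanged. A minimal sketch checking that the two paths agree (the tensor shape below is an illustrative assumption, not the actual conv output):

import torch

# Stand-in for the conv output of shape (N, odim, T', F'); these numbers are assumptions.
x = torch.rand(1, 8, 6, 19)
conv_out_dim = 8 * 19

old = x.transpose(1, 2).contiguous().view(1, -1, conv_out_dim)  # removed path
new = x.permute(0, 2, 1, 3).reshape(1, -1, conv_out_dim)  # added path

assert torch.equal(old, new)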

View File

@@ -55,8 +55,7 @@ class NonScaledNorm(nn.Module):
         if not torch.jit.is_tracing():
             assert x.shape[self.channel_dim] == self.num_channels
         scales = (
-            torch.mean(x.pow(2), dim=self.channel_dim, keepdim=True)
-            + self.eps_exp
+            torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
         ).pow(-0.5)
         return x * scales
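
Only the notation changes here: x.pow(2) is rewritten as x * x (presumably friendlier to tracing and ncnn export), so the scale computed by NonScaledNorm is identical. A small sketch, with a placeholder standing in for self.eps_exp:

import torch

x = torch.rand(2, 6, 512)
eps_exp = 0.25  # placeholder for self.eps_exp; the real value comes from the module

old = (torch.mean(x.pow(2), dim=-1, keepdim=True) + eps_exp).pow(-0.5)
new = (torch.mean(x * x, dim=-1, keepdim=True) + eps_exp).pow(-0.5)

assert torch.allclose(old, new)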

View File

@@ -0,0 +1,50 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
from conformer import Conv2dSubsampling
from scaling_converter import convert_scaled_to_non_scaled


class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        num_features = 80
        subsampling_factor = 4
        d_model = 512

        self.num_features = num_features
        self.subsampling_factor = subsampling_factor
        self.encoder_embed = Conv2dSubsampling(num_features, d_model)

    def forward(self, x: torch.Tensor):
        """
        Args:
          x:
            (N, T, C)
        """
        x = self.encoder_embed(x)
        return x


def generate_pt():
    f = Foo()
    f.eval()
    f = convert_scaled_to_non_scaled(f)
    f.encoder_embed.for_ncnn = True

    x = torch.rand(1, 30, 80)
    y = f(x)
    print("y.shape", y.shape)

    m = torch.jit.trace(f, x)
    m.save("foo/encoder_embed.pt")


def main():
    generate_pt()


if __name__ == "__main__":
    torch.manual_seed(20220809)
    main()
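
For reference, using the shape comment in Conv2dSubsampling above, the traced module maps the (1, 30, 80) example input to (1, 6, 512); a quick check of the time dimension:

T = 30
T_out = ((T - 1) // 2 - 1) // 2  # matches ((T-1)//2 - 1)//2 from the shape comment above
print(T_out)  # 6, so y.shape printed by generate_pt() should be (1, 6, 512)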

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python3

import math

import ncnn
import numpy as np
import torch

LOG_EPS = math.log(1e-10)


@torch.no_grad()
def main():
    x = torch.rand(30, 80)  # (T, C)
    m = torch.jit.load("foo/encoder_embed.pt")
    t = m(x.unsqueeze(0))  # batch size is 1
    t = t.squeeze(0)  # (T, C)
    print(t.shape)

    param = "foo/encoder_embed.ncnn.param"
    model = "foo/encoder_embed.ncnn.bin"
    with ncnn.Net() as net:
        net.load_param(param)
        net.load_model(model)
        with net.create_extractor() as ex:
            ex.input("in0", ncnn.Mat(x.numpy()).clone())
            ret, ncnn_out0 = ex.extract("out0")
            assert ret == 0, ret
            n = np.array(ncnn_out0)
            print(n.shape)  # (6, 512), (T, C)
            n = torch.from_numpy(n)

            print(t.reshape(-1)[:10])
            print(n.reshape(-1)[:10])
            assert torch.allclose(t, n, atol=1e-2), (t - n).abs().max()


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    torch.manual_seed(20220808)
    main()
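
Note: foo/encoder_embed.ncnn.param and foo/encoder_embed.ncnn.bin are not created by this commit; they are assumed to be produced from foo/encoder_embed.pt with ncnn's pnnx converter (roughly pnnx encoder_embed.pt inputshape=[1,30,80]) before this test is run.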