mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Test encoder_embed
This commit is contained in:
parent
365c6aa045
commit
10360bed41
@ -1611,9 +1611,8 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||||
if torch.jit.is_tracing() and self.for_ncnn:
|
if torch.jit.is_tracing() and self.for_ncnn:
|
||||||
x = self.out(
|
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
|
||||||
x.transpose(1, 2).contiguous().view(1, -1, self.conv_out_dim)
|
x = self.out(x)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
b, c, t, f = x.size()
|
b, c, t, f = x.size()
|
||||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
|
@ -55,8 +55,7 @@ class NonScaledNorm(nn.Module):
|
|||||||
if not torch.jit.is_tracing():
|
if not torch.jit.is_tracing():
|
||||||
assert x.shape[self.channel_dim] == self.num_channels
|
assert x.shape[self.channel_dim] == self.num_channels
|
||||||
scales = (
|
scales = (
|
||||||
torch.mean(x.pow(2), dim=self.channel_dim, keepdim=True)
|
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
|
||||||
+ self.eps_exp
|
|
||||||
).pow(-0.5)
|
).pow(-0.5)
|
||||||
return x * scales
|
return x * scales
|
||||||
|
|
||||||
|
@ -0,0 +1,50 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from conformer import Conv2dSubsampling
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
|
||||||
|
|
||||||
|
class Foo(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
num_features = 80
|
||||||
|
subsampling_factor = 4
|
||||||
|
d_model = 512
|
||||||
|
self.num_features = num_features
|
||||||
|
self.subsampling_factor = subsampling_factor
|
||||||
|
|
||||||
|
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
(N, T, C)
|
||||||
|
"""
|
||||||
|
x = self.encoder_embed(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pt():
|
||||||
|
f = Foo()
|
||||||
|
f.eval()
|
||||||
|
f = convert_scaled_to_non_scaled(f)
|
||||||
|
f.encoder_embed.for_ncnn = True
|
||||||
|
x = torch.rand(1, 30, 80)
|
||||||
|
y = f(x)
|
||||||
|
print("y.shape", y.shape)
|
||||||
|
m = torch.jit.trace(f, x)
|
||||||
|
m.save("foo/encoder_embed.pt")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
generate_pt()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20220809)
|
||||||
|
main()
|
43
egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_encoder_embed.py
Executable file
43
egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_encoder_embed.py
Executable file
@ -0,0 +1,43 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import ncnn
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
x = torch.rand(30, 80) # (T, C)
|
||||||
|
m = torch.jit.load("foo/encoder_embed.pt")
|
||||||
|
t = m(x.unsqueeze(0)) # bach size is 1
|
||||||
|
t = t.squeeze(0) # (T, C)
|
||||||
|
print(t.shape)
|
||||||
|
|
||||||
|
param = "foo/encoder_embed.ncnn.param"
|
||||||
|
model = "foo/encoder_embed.ncnn.bin"
|
||||||
|
with ncnn.Net() as net:
|
||||||
|
net.load_param(param)
|
||||||
|
net.load_model(model)
|
||||||
|
with net.create_extractor() as ex:
|
||||||
|
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||||
|
|
||||||
|
ret, ncnn_out0 = ex.extract("out0")
|
||||||
|
assert ret == 0, ret
|
||||||
|
n = np.array(ncnn_out0)
|
||||||
|
print(n.shape) # (6, 512), (T, C)
|
||||||
|
n = torch.from_numpy(n)
|
||||||
|
|
||||||
|
print(t.reshape(-1)[:10])
|
||||||
|
print(n.reshape(-1)[:10])
|
||||||
|
assert torch.allclose(t, n, atol=1e-2), (t - n).abs().max()
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20220808)
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user