Test encoder_embed

commit 10360bed41 (parent 365c6aa045)
@@ -1611,9 +1611,8 @@ class Conv2dSubsampling(nn.Module):
         x = self.conv(x)
         # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         if torch.jit.is_tracing() and self.for_ncnn:
-            x = self.out(
-                x.transpose(1, 2).contiguous().view(1, -1, self.conv_out_dim)
-            )
+            x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
+            x = self.out(x)
         else:
             b, c, t, f = x.size()
             x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
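The two formulations produce identical results for batch size 1, the only case the ncnn branch handles; the rewrite presumably just gives the tracer a permute/reshape pattern that the ncnn converter copes with better than transpose/contiguous/view. A minimal sketch checking the equivalence (the shapes below are illustrative, matching idim=80 and odim=512 from the test, and are not part of the commit):

import torch

# (N, odim, T', F') as produced by the subsampling convolutions.
x = torch.rand(1, 512, 6, 19)
b, c, t, f = x.size()
conv_out_dim = c * f  # 512 * 19

# Old path: transpose + contiguous + view.
old = x.transpose(1, 2).contiguous().view(1, -1, conv_out_dim)
# New path: a single permute + reshape.
new = x.permute(0, 2, 1, 3).reshape(1, -1, conv_out_dim)
assert torch.equal(old, new)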
@@ -55,8 +55,7 @@ class NonScaledNorm(nn.Module):
         if not torch.jit.is_tracing():
             assert x.shape[self.channel_dim] == self.num_channels
         scales = (
-            torch.mean(x.pow(2), dim=self.channel_dim, keepdim=True)
-            + self.eps_exp
+            torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
         ).pow(-0.5)
         return x * scales
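x * x is numerically identical to x.pow(2); the rewrite presumably keeps the traced graph limited to ops the ncnn converter supports. A quick check (a sketch; eps_exp is stood in here by a small constant, not taken from the real module):

import torch

x = torch.rand(4, 10, 512)
channel_dim = -1
eps_exp = 1.0e-10  # stand-in for self.eps_exp

old = (torch.mean(x.pow(2), dim=channel_dim, keepdim=True) + eps_exp).pow(-0.5)
new = (torch.mean(x * x, dim=channel_dim, keepdim=True) + eps_exp).pow(-0.5)
assert torch.allclose(old, new)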
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+
+
+import torch
+import torch.nn as nn
+from conformer import Conv2dSubsampling
+from scaling_converter import convert_scaled_to_non_scaled
+
+
+class Foo(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        num_features = 80
+        subsampling_factor = 4
+        d_model = 512
+        self.num_features = num_features
+        self.subsampling_factor = subsampling_factor
+
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+          x:
+            (N, T, C)
+        """
+        x = self.encoder_embed(x)
+        return x
+
+
+def generate_pt():
+    f = Foo()
+    f.eval()
+    f = convert_scaled_to_non_scaled(f)
+    f.encoder_embed.for_ncnn = True
+    x = torch.rand(1, 30, 80)
+    y = f(x)
+    print("y.shape", y.shape)
+    m = torch.jit.trace(f, x)
+    m.save("foo/encoder_embed.pt")
+
+
+def main():
+    generate_pt()
+
+
+if __name__ == "__main__":
+    torch.manual_seed(20220809)
+    main()
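A quick way to sanity-check the artifact generate_pt() writes is to load it back with torch.jit.load; a sketch (the expected output length follows from the two stride-2 convolutions: T' = ((30 - 1) // 2 - 1) // 2 = 6). Converting foo/encoder_embed.pt into the foo/encoder_embed.ncnn.param / foo/encoder_embed.ncnn.bin pair consumed by the test below is presumably done with pnnx; that step is not part of this commit.

import torch

m = torch.jit.load("foo/encoder_embed.pt")
x = torch.rand(1, 30, 80)  # (N, T, C)
y = m(x)
print(y.shape)  # expected: torch.Size([1, 6, 512])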
egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_encoder_embed.py (new executable file, 43 lines)
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+
+import math
+
+import ncnn
+import numpy as np
+import torch
+
+LOG_EPS = math.log(1e-10)
+
+
+@torch.no_grad()
+def main():
+    x = torch.rand(30, 80)  # (T, C)
+    m = torch.jit.load("foo/encoder_embed.pt")
+    t = m(x.unsqueeze(0))  # batch size is 1
+    t = t.squeeze(0)  # (T, C)
+    print(t.shape)
+
+    param = "foo/encoder_embed.ncnn.param"
+    model = "foo/encoder_embed.ncnn.bin"
+    with ncnn.Net() as net:
+        net.load_param(param)
+        net.load_model(model)
+        with net.create_extractor() as ex:
+            ex.input("in0", ncnn.Mat(x.numpy()).clone())
+
+            ret, ncnn_out0 = ex.extract("out0")
+            assert ret == 0, ret
+            n = np.array(ncnn_out0)
+            print(n.shape)  # (6, 512), (T, C)
+            n = torch.from_numpy(n)
+
+    print(t.reshape(-1)[:10])
+    print(n.reshape(-1)[:10])
+    assert torch.allclose(t, n, atol=1e-2), (t - n).abs().max()
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+if __name__ == "__main__":
+    torch.manual_seed(20220808)
+    main()