mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Test ScaledConv2D.
This commit is contained in:
parent
6af5a82d8f
commit
b406c1beff
56
egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn.py
Executable file
56
egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn.py
Executable file
@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import ncnn
|
||||
import numpy as np
|
||||
import torch
|
||||
from scaling import ScaledConv2d
|
||||
from scaling_converter import scaled_conv2d_to_conv2d
|
||||
|
||||
|
||||
def generate_scaled_conv2d():
|
||||
f = ScaledConv2d(in_channels=1, out_channels=2, kernel_size=3, padding=1)
|
||||
f = scaled_conv2d_to_conv2d(f)
|
||||
print(f)
|
||||
x = torch.rand(1, 1, 6, 8) # NCHW
|
||||
m = torch.jit.trace(f, x)
|
||||
m.save("foo/scaled_conv2d.pt")
|
||||
print(m.graph)
|
||||
|
||||
|
||||
def compare_scaled_conv2d():
|
||||
param = "foo/scaled_conv2d.ncnn.param"
|
||||
model = "foo/scaled_conv2d.ncnn.bin"
|
||||
|
||||
with ncnn.Net() as net:
|
||||
with net.create_extractor() as ex:
|
||||
net = ncnn.Net()
|
||||
net.load_param(param)
|
||||
net.load_model(model)
|
||||
|
||||
ex = net.create_extractor()
|
||||
x = torch.rand(1, 6, 5) # CHW
|
||||
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||
ret, out0 = ex.extract("out0")
|
||||
assert ret == 0
|
||||
out0 = np.array(out0)
|
||||
out0 = torch.from_numpy(out0)
|
||||
|
||||
m = torch.jit.load("foo/scaled_conv2d.pt")
|
||||
y = m(x.unsqueeze(0)).squeeze(0)
|
||||
|
||||
assert torch.allclose(out0, y, atol=1e-3), (out0 - y).abs().max()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
if not Path("foo/scaled_conv2d.ncnn.param").is_file():
|
||||
generate_scaled_conv2d()
|
||||
else:
|
||||
compare_scaled_conv2d()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220803)
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user