From 317f47fb3782bd7394befc011c228fd612f6b793 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 9 Aug 2022 23:32:28 +0800
Subject: [PATCH] test RelPositionalEncoding

---
Review note: the blob ids on the "index" lines below were not regenerated
after the review edits to the two scripts.

 .../test_generate_relpositional_encoding.py   | 65 +++++++++++++++++++
 .../test_ncnn_relpositional_encoding.py       | 49 ++++++++++++++
 2 files changed, 114 insertions(+)
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/test_generate_relpositional_encoding.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_relpositional_encoding.py

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_generate_relpositional_encoding.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_generate_relpositional_encoding.py
new file mode 100755
index 000000000..915b3cd47
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_generate_relpositional_encoding.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+"""Trace RelPositionalEncoding and save it as foo/encoder_pos.pt.
+
+The saved TorchScript module is what test_ncnn_relpositional_encoding.py loads.
+"""
+
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from conformer import RelPositionalEncoding
+from scaling_converter import convert_scaled_to_non_scaled
+
+
+class Foo(nn.Module):
+    """Wrap RelPositionalEncoding so that it can be traced on its own."""
+
+    def __init__(self):
+        super().__init__()
+        d_model = 512
+        dropout = 0.1
+
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+          x:
+            (N, T, C)
+        Returns:
+          The two values returned by RelPositionalEncoding, unchanged.
+        """
+        y, pos_emb = self.encoder_pos(x)
+        return y, pos_emb
+
+
+def generate_pt():
+    """Trace Foo on a random input and write it to foo/encoder_pos.pt."""
+    f = Foo()
+    f.eval()
+    f = convert_scaled_to_non_scaled(f)
+    # NOTE(review): C is 4 here although d_model is 512 and the ncnn test
+    # feeds C == 512; confirm that the traced module does not depend on C.
+    x = torch.rand(1, 6, 4)
+    y, pos_emb = f(x)
+    print("y.shape", y.shape)
+    print("pos_emb.shape", pos_emb.shape)
+    m = torch.jit.trace(f, x)
+
+    # Saving does not create missing directories, so make sure foo/ exists.
+    Path("foo").mkdir(exist_ok=True)
+    m.save("foo/encoder_pos.pt")
+    print(m.encoder_pos.pe[0].shape)
+    print(type(m.encoder_pos.pe[0]))
+    print(m.graph)
+
+
+def main():
+    """Script entry point."""
+    generate_pt()
+
+
+if __name__ == "__main__":
+    torch.manual_seed(20220809)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_relpositional_encoding.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_relpositional_encoding.py
new file mode 100755
index 000000000..1ac1e526e
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_relpositional_encoding.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+"""Check the ncnn export of RelPositionalEncoding against TorchScript.
+
+Expects foo/encoder_pos.pt, foo/encoder_pos.ncnn.param and
+foo/encoder_pos.ncnn.bin to exist already.
+"""
+
+import ncnn
+import numpy as np
+import torch
+
+
+@torch.no_grad()
+def main():
+    """Feed one random input to both models and compare the outputs."""
+    x = torch.rand(100, 512)  # (T, C)
+    m = torch.jit.load("foo/encoder_pos.pt")
+    _, t = m(x.unsqueeze(0))  # batch size is 1
+    # NOTE(review): the original comment said (T, C); for a relative
+    # positional encoding the time axis may differ from T -- confirm.
+    t = t.squeeze(0)
+
+    param = "foo/encoder_pos.ncnn.param"
+    model = "foo/encoder_pos.ncnn.bin"
+    with ncnn.Net() as net:
+        net.load_param(param)
+        net.load_model(model)
+        with net.create_extractor() as ex:
+            ex.input("in0", ncnn.Mat(x.numpy()).clone())
+
+            ret, ncnn_out0 = ex.extract("out1")
+            assert ret == 0, ret
+            n = np.array(ncnn_out0)
+            print(n.shape)
+            n = torch.from_numpy(n)
+
+            print(t.reshape(-1)[:10])
+            print(n.reshape(-1)[:10])
+            # torch.allclose() broadcasts, so check the shapes explicitly.
+            assert t.shape == n.shape, (t.shape, n.shape)
+            assert torch.allclose(t, n), (t - n).abs().max()
+
+
+# Restrict PyTorch to one intra-op and one inter-op thread.
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+if __name__ == "__main__":
+    torch.manual_seed(20220808)
+    main()