fix decoding for ncnn (#828)

This commit is contained in:
Fangjun Kuang 2023-01-10 20:52:13 +08:00 committed by GitHub
parent fcffa593f0
commit c05f5d76df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 9 deletions

View File

@@ -131,6 +131,8 @@ class Model:
encoder_net = ncnn.Net() encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False encoder_net.opt.use_fp16_storage = False
encoder_net.opt.num_threads = 4
encoder_param = args.encoder_param_filename encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename encoder_model = args.encoder_bin_filename
@@ -144,6 +146,7 @@ class Model:
decoder_model = args.decoder_bin_filename decoder_model = args.decoder_bin_filename
decoder_net = ncnn.Net() decoder_net = ncnn.Net()
decoder_net.opt.num_threads = 4
decoder_net.load_param(decoder_param) decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model) decoder_net.load_model(decoder_model)
@@ -154,6 +157,8 @@ class Model:
joiner_param = args.joiner_param_filename joiner_param = args.joiner_param_filename
joiner_model = args.joiner_bin_filename joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net() joiner_net = ncnn.Net()
joiner_net.opt.num_threads = 4
joiner_net.load_param(joiner_param) joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model) joiner_net.load_model(joiner_model)
@@ -176,7 +181,6 @@ class Model:
- next_states, a list of tensors containing the next states - next_states, a list of tensors containing the next states
""" """
with self.encoder_net.create_extractor() as ex: with self.encoder_net.create_extractor() as ex:
ex.set_num_threads(4)
ex.input("in0", ncnn.Mat(x.numpy()).clone()) ex.input("in0", ncnn.Mat(x.numpy()).clone())
# layer0 in2-in5 # layer0 in2-in5
@@ -220,7 +224,6 @@ class Model:
assert decoder_input.dtype == torch.int32 assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex: with self.decoder_net.create_extractor() as ex:
ex.set_num_threads(4)
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0") ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret assert ret == 0, ret
@@ -229,7 +232,6 @@ class Model:
def run_joiner(self, encoder_out, decoder_out): def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex: with self.joiner_net.create_extractor() as ex:
ex.set_num_threads(4)
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0") ret, ncnn_out0 = ex.extract("out0")

View File

@@ -104,6 +104,8 @@ class Model:
encoder_net = ncnn.Net() encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False encoder_net.opt.use_fp16_storage = False
encoder_net.opt.num_threads = 4
encoder_param = args.encoder_param_filename encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename encoder_model = args.encoder_bin_filename
@@ -118,6 +120,7 @@ class Model:
decoder_net = ncnn.Net() decoder_net = ncnn.Net()
decoder_net.opt.use_packing_layout = False decoder_net.opt.use_packing_layout = False
decoder_net.opt.num_threads = 4
decoder_net.load_param(decoder_param) decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model) decoder_net.load_model(decoder_model)
@@ -129,6 +132,8 @@ class Model:
joiner_model = args.joiner_bin_filename joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net() joiner_net = ncnn.Net()
joiner_net.opt.use_packing_layout = False joiner_net.opt.use_packing_layout = False
joiner_net.opt.num_threads = 4
joiner_net.load_param(joiner_param) joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model) joiner_net.load_model(joiner_model)
@@ -136,7 +141,6 @@ class Model:
def run_encoder(self, x, states): def run_encoder(self, x, states):
with self.encoder_net.create_extractor() as ex: with self.encoder_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(x.numpy()).clone()) ex.input("in0", ncnn.Mat(x.numpy()).clone())
x_lens = torch.tensor([x.size(0)], dtype=torch.float32) x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
@@ -165,7 +169,6 @@ class Model:
assert decoder_input.dtype == torch.int32 assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex: with self.decoder_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0") ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret assert ret == 0, ret
@@ -174,7 +177,6 @@ class Model:
def run_joiner(self, encoder_out, decoder_out): def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex: with self.joiner_net.create_extractor() as ex:
ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0") ret, ncnn_out0 = ex.extract("out0")

View File

@@ -92,6 +92,8 @@ class Model:
encoder_net = ncnn.Net() encoder_net = ncnn.Net()
encoder_net.opt.use_packing_layout = False encoder_net.opt.use_packing_layout = False
encoder_net.opt.use_fp16_storage = False encoder_net.opt.use_fp16_storage = False
encoder_net.opt.num_threads = 4
encoder_param = args.encoder_param_filename encoder_param = args.encoder_param_filename
encoder_model = args.encoder_bin_filename encoder_model = args.encoder_bin_filename
@@ -106,6 +108,7 @@ class Model:
decoder_net = ncnn.Net() decoder_net = ncnn.Net()
decoder_net.opt.use_packing_layout = False decoder_net.opt.use_packing_layout = False
decoder_net.opt.num_threads = 4
decoder_net.load_param(decoder_param) decoder_net.load_param(decoder_param)
decoder_net.load_model(decoder_model) decoder_net.load_model(decoder_model)
@@ -117,6 +120,8 @@ class Model:
joiner_model = args.joiner_bin_filename joiner_model = args.joiner_bin_filename
joiner_net = ncnn.Net() joiner_net = ncnn.Net()
joiner_net.opt.use_packing_layout = False joiner_net.opt.use_packing_layout = False
joiner_net.opt.num_threads = 4
joiner_net.load_param(joiner_param) joiner_net.load_param(joiner_param)
joiner_net.load_model(joiner_model) joiner_net.load_model(joiner_model)
@@ -124,7 +129,6 @@ class Model:
def run_encoder(self, x, states): def run_encoder(self, x, states):
with self.encoder_net.create_extractor() as ex: with self.encoder_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(x.numpy()).clone()) ex.input("in0", ncnn.Mat(x.numpy()).clone())
x_lens = torch.tensor([x.size(0)], dtype=torch.float32) x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
@@ -153,7 +157,6 @@ class Model:
assert decoder_input.dtype == torch.int32 assert decoder_input.dtype == torch.int32
with self.decoder_net.create_extractor() as ex: with self.decoder_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0") ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret assert ret == 0, ret
@@ -162,7 +165,6 @@ class Model:
def run_joiner(self, encoder_out, decoder_out): def run_joiner(self, encoder_out, decoder_out):
with self.joiner_net.create_extractor() as ex: with self.joiner_net.create_extractor() as ex:
# ex.set_num_threads(10)
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0") ret, ncnn_out0 = ex.extract("out0")