mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
fix decoding for ncnn (#828)
This commit is contained in:
parent
fcffa593f0
commit
c05f5d76df
@ -131,6 +131,8 @@ class Model:
|
||||
encoder_net = ncnn.Net()
|
||||
encoder_net.opt.use_packing_layout = False
|
||||
encoder_net.opt.use_fp16_storage = False
|
||||
encoder_net.opt.num_threads = 4
|
||||
|
||||
encoder_param = args.encoder_param_filename
|
||||
encoder_model = args.encoder_bin_filename
|
||||
|
||||
@ -144,6 +146,7 @@ class Model:
|
||||
decoder_model = args.decoder_bin_filename
|
||||
|
||||
decoder_net = ncnn.Net()
|
||||
decoder_net.opt.num_threads = 4
|
||||
|
||||
decoder_net.load_param(decoder_param)
|
||||
decoder_net.load_model(decoder_model)
|
||||
@ -154,6 +157,8 @@ class Model:
|
||||
joiner_param = args.joiner_param_filename
|
||||
joiner_model = args.joiner_bin_filename
|
||||
joiner_net = ncnn.Net()
|
||||
joiner_net.opt.num_threads = 4
|
||||
|
||||
joiner_net.load_param(joiner_param)
|
||||
joiner_net.load_model(joiner_model)
|
||||
|
||||
@ -176,7 +181,6 @@ class Model:
|
||||
- next_states, a list of tensors containing the next states
|
||||
"""
|
||||
with self.encoder_net.create_extractor() as ex:
|
||||
ex.set_num_threads(4)
|
||||
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||
|
||||
# layer0 in2-in5
|
||||
@ -220,7 +224,6 @@ class Model:
|
||||
assert decoder_input.dtype == torch.int32
|
||||
|
||||
with self.decoder_net.create_extractor() as ex:
|
||||
ex.set_num_threads(4)
|
||||
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
@ -229,7 +232,6 @@ class Model:
|
||||
|
||||
def run_joiner(self, encoder_out, decoder_out):
|
||||
with self.joiner_net.create_extractor() as ex:
|
||||
ex.set_num_threads(4)
|
||||
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
|
||||
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
|
@ -104,6 +104,8 @@ class Model:
|
||||
encoder_net = ncnn.Net()
|
||||
encoder_net.opt.use_packing_layout = False
|
||||
encoder_net.opt.use_fp16_storage = False
|
||||
encoder_net.opt.num_threads = 4
|
||||
|
||||
encoder_param = args.encoder_param_filename
|
||||
encoder_model = args.encoder_bin_filename
|
||||
|
||||
@ -118,6 +120,7 @@ class Model:
|
||||
|
||||
decoder_net = ncnn.Net()
|
||||
decoder_net.opt.use_packing_layout = False
|
||||
decoder_net.opt.num_threads = 4
|
||||
|
||||
decoder_net.load_param(decoder_param)
|
||||
decoder_net.load_model(decoder_model)
|
||||
@ -129,6 +132,8 @@ class Model:
|
||||
joiner_model = args.joiner_bin_filename
|
||||
joiner_net = ncnn.Net()
|
||||
joiner_net.opt.use_packing_layout = False
|
||||
joiner_net.opt.num_threads = 4
|
||||
|
||||
joiner_net.load_param(joiner_param)
|
||||
joiner_net.load_model(joiner_model)
|
||||
|
||||
@ -136,7 +141,6 @@ class Model:
|
||||
|
||||
def run_encoder(self, x, states):
|
||||
with self.encoder_net.create_extractor() as ex:
|
||||
ex.set_num_threads(10)
|
||||
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||
x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
|
||||
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
|
||||
@ -165,7 +169,6 @@ class Model:
|
||||
assert decoder_input.dtype == torch.int32
|
||||
|
||||
with self.decoder_net.create_extractor() as ex:
|
||||
ex.set_num_threads(10)
|
||||
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
@ -174,7 +177,6 @@ class Model:
|
||||
|
||||
def run_joiner(self, encoder_out, decoder_out):
|
||||
with self.joiner_net.create_extractor() as ex:
|
||||
ex.set_num_threads(10)
|
||||
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
|
||||
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
|
@ -92,6 +92,8 @@ class Model:
|
||||
encoder_net = ncnn.Net()
|
||||
encoder_net.opt.use_packing_layout = False
|
||||
encoder_net.opt.use_fp16_storage = False
|
||||
encoder_net.opt.num_threads = 4
|
||||
|
||||
encoder_param = args.encoder_param_filename
|
||||
encoder_model = args.encoder_bin_filename
|
||||
|
||||
@ -106,6 +108,7 @@ class Model:
|
||||
|
||||
decoder_net = ncnn.Net()
|
||||
decoder_net.opt.use_packing_layout = False
|
||||
decoder_net.opt.num_threads = 4
|
||||
|
||||
decoder_net.load_param(decoder_param)
|
||||
decoder_net.load_model(decoder_model)
|
||||
@ -117,6 +120,8 @@ class Model:
|
||||
joiner_model = args.joiner_bin_filename
|
||||
joiner_net = ncnn.Net()
|
||||
joiner_net.opt.use_packing_layout = False
|
||||
joiner_net.opt.num_threads = 4
|
||||
|
||||
joiner_net.load_param(joiner_param)
|
||||
joiner_net.load_model(joiner_model)
|
||||
|
||||
@ -124,7 +129,6 @@ class Model:
|
||||
|
||||
def run_encoder(self, x, states):
|
||||
with self.encoder_net.create_extractor() as ex:
|
||||
# ex.set_num_threads(10)
|
||||
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||
x_lens = torch.tensor([x.size(0)], dtype=torch.float32)
|
||||
ex.input("in1", ncnn.Mat(x_lens.numpy()).clone())
|
||||
@ -153,7 +157,6 @@ class Model:
|
||||
assert decoder_input.dtype == torch.int32
|
||||
|
||||
with self.decoder_net.create_extractor() as ex:
|
||||
# ex.set_num_threads(10)
|
||||
ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
assert ret == 0, ret
|
||||
@ -162,7 +165,6 @@ class Model:
|
||||
|
||||
def run_joiner(self, encoder_out, decoder_out):
|
||||
with self.joiner_net.create_extractor() as ex:
|
||||
# ex.set_num_threads(10)
|
||||
ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone())
|
||||
ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone())
|
||||
ret, ncnn_out0 = ex.extract("out0")
|
||||
|
Loading…
x
Reference in New Issue
Block a user