Fix CI tests (#1266)

Fangjun Kuang 2023-09-21 21:16:14 +08:00 committed by GitHub
parent 45d60ef262
commit f5dc957d44
17 changed files with 62 additions and 0 deletions
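
Every one of the 62 added lines is the same change: each ort.InferenceSession(...) call now passes providers=["CPUExecutionProvider"] explicitly. This appears to be needed because recent onnxruntime releases (around 1.16, current at the time of this commit) expose more than one execution provider even in the CPU-only wheel and refuse to create a session unless the providers list is set explicitly, which is what broke CI. A minimal sketch of the pattern, with a placeholder model path rather than a file from this repository, looks like this:

import onnxruntime as ort

session_opts = ort.SessionOptions()
session_opts.intra_op_num_threads = 1  # illustrative value; tune for the CI machine

# Pin the CPU provider explicitly so newer onnxruntime versions do not
# fail with an error asking for an explicit providers parameter.
session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    sess_options=session_opts,
    providers=["CPUExecutionProvider"],
)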

View File

@@ -151,12 +151,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -170,6 +172,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -152,12 +152,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -171,6 +173,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -136,6 +136,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         self.init_encoder_states()
@@ -184,6 +185,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -197,6 +199,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -129,6 +129,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         self.init_encoder_states()
@@ -166,6 +167,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -179,6 +181,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -172,30 +172,35 @@ class Model:
         self.encoder = ort.InferenceSession(
             args.encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, args):
         self.decoder = ort.InferenceSession(
             args.decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_joiner(self, args):
         self.joiner = ort.InferenceSession(
             args.joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_joiner_encoder_proj(self, args):
         self.joiner_encoder_proj = ort.InferenceSession(
             args.joiner_encoder_proj_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_joiner_decoder_proj(self, args):
         self.joiner_decoder_proj = ort.InferenceSession(
             args.joiner_decoder_proj_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

View File

@@ -150,12 +150,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -169,6 +171,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -78,6 +78,7 @@ def test_conv2d_subsampling():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -133,6 +134,7 @@ def test_rel_pos():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -220,6 +222,7 @@ def test_conformer_encoder_layer():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -304,6 +307,7 @@ def test_conformer_encoder():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -359,6 +363,7 @@ def test_conformer():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

View File

@@ -138,6 +138,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         self.init_encoder_states()
@@ -185,6 +186,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -198,6 +200,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -74,6 +74,7 @@ def test_conv2d_subsampling():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -128,6 +129,7 @@ def test_rel_pos():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -204,6 +206,7 @@ def test_zipformer_encoder_layer():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -284,6 +287,7 @@ def test_zipformer_encoder():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -338,6 +342,7 @@ def test_zipformer():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

View File

@@ -326,41 +326,49 @@ def main():
     encoder = ort.InferenceSession(
         args.encoder_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     decoder = ort.InferenceSession(
         args.decoder_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     joiner = ort.InferenceSession(
         args.joiner_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     joiner_encoder_proj = ort.InferenceSession(
         args.joiner_encoder_proj_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     joiner_decoder_proj = ort.InferenceSession(
         args.joiner_decoder_proj_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     lconv = ort.InferenceSession(
         args.lconv_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     frame_reducer = ort.InferenceSession(
         args.frame_reducer_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     ctc_output = ort.InferenceSession(
         args.ctc_output_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     sp = spm.SentencePieceProcessor()

View File

@@ -130,6 +130,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         self.init_encoder_states()
@@ -229,6 +230,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -242,6 +244,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -146,6 +146,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         self.init_encoder_states()
@@ -236,6 +237,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -249,6 +251,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -151,12 +151,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -170,6 +172,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -258,6 +258,7 @@ def main():
     encoder_session = ort.InferenceSession(
         args.onnx_encoder_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     test_encoder(model, encoder_session)
@@ -265,6 +266,7 @@ def main():
     decoder_session = ort.InferenceSession(
         args.onnx_decoder_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     test_decoder(model, decoder_session)
@@ -272,14 +274,17 @@ def main():
     joiner_session = ort.InferenceSession(
         args.onnx_joiner_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )
     joiner_encoder_proj_session = ort.InferenceSession(
         args.onnx_joiner_encoder_proj_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )
     joiner_decoder_proj_session = ort.InferenceSession(
         args.onnx_joiner_decoder_proj_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )
     test_joiner(
         model,

View File

@@ -139,6 +139,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         self.init_encoder_states()
@@ -186,6 +187,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -199,6 +201,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -158,12 +158,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@@ -177,6 +179,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@@ -54,6 +54,7 @@ class OnnxModel:
         self.model = ort.InferenceSession(
             nn_model,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         meta = self.model.get_modelmeta().custom_metadata_map
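
Not part of the commit, but a quick way to see why the explicit argument matters on a given CI runner is to list the providers the installed onnxruntime build exposes and confirm what a session actually selects. This is a small diagnostic sketch; the model path is a placeholder.

import onnxruntime as ort

print(ort.__version__)
print(ort.get_available_providers())  # recent CPU wheels may list more than one provider

sess = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=["CPUExecutionProvider"],
)
print(sess.get_providers())  # expected: ['CPUExecutionProvider']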