commit b984fb1b43
Merge branch 'master' of github.com:marcoyang1998/icefall into train_rnnlm
(mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-06 23:54:17 +00:00)
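The hunks below apply one mechanical change across the ONNX export, inference, and test scripts: every ort.InferenceSession is now constructed with an explicit providers=["CPUExecutionProvider"] argument. On onnxruntime builds that ship more than one execution provider (1.9 and later), omitting providers fails with an error, and passing it explicitly is harmless on CPU-only builds. A minimal sketch of the resulting pattern, with a placeholder model path and session options of the kind the icefall recipes typically set:

import onnxruntime as ort

session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1  # typical recipe setting, assumed here
session_opts.intra_op_num_threads = 1

session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    sess_options=session_opts,
    providers=["CPUExecutionProvider"],  # the argument this commit adds
)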
@@ -322,6 +322,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,
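Every export_decoder_model_onnx hunk in this commit inserts the same line: the decoder is passed through torch.jit.script before torch.onnx.export. A plausible reading is that torch.onnx.export traces a plain nn.Module, which freezes any data-dependent control flow in the decoder's forward, while exporting the scripted module preserves it. The sketch below mirrors the hunk; the stand-in Decoder class, the output path, the opset version, and the input/output names are assumptions, not taken from the diff:

import torch

class Decoder(torch.nn.Module):
    # Stand-in so the sketch runs on its own; the real decoder_model
    # comes from the recipe being exported.
    def __init__(self, vocab_size: int = 500, embed_dim: int = 16):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        return self.embedding(y)

decoder_model = Decoder()
context_size = 2  # assumed value
y = torch.zeros(10, context_size, dtype=torch.int64)

decoder_model = torch.jit.script(decoder_model)  # script rather than trace
torch.onnx.export(
    decoder_model,
    y,
    "decoder.onnx",                 # illustrative output path
    opset_version=13,               # assumed; the diff does not show it
    input_names=["y"],              # assumed names
    output_names=["decoder_out"],
    dynamic_axes={"y": {0: "N"}, "decoder_out": {0: "N"}},
)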
@@ -151,12 +151,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
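Several of these OnnxModel hunks end by reading get_modelmeta().custom_metadata_map, which is how the runtime side recovers values such as context_size and vocab_size that the export script stored inside the model file. A hedged sketch of both sides; the helper name, the keys, and the filename are illustrative:

import onnx

def add_meta_data(filename: str, meta_data: dict) -> None:
    # Store key/value pairs in the model's metadata_props so that
    # ort.InferenceSession exposes them via custom_metadata_map.
    model = onnx.load(filename)
    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)
    onnx.save(model, filename)

add_meta_data("decoder.onnx", {"context_size": 2, "vocab_size": 500})

# Reading side, as in the hunks above:
#   decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
#   context_size = int(decoder_meta["context_size"])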
@@ -170,6 +172,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

@@ -330,6 +330,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -152,12 +152,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -171,6 +173,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -401,6 +401,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -136,6 +136,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )
         self.init_encoder_states()

@@ -184,6 +185,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -197,6 +199,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -359,6 +359,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -356,6 +356,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -129,6 +129,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )
         self.init_encoder_states()

@@ -166,6 +167,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -179,6 +181,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -172,30 +172,35 @@ class Model:
         self.encoder = ort.InferenceSession(
             args.encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, args):
         self.decoder = ort.InferenceSession(
             args.decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_joiner(self, args):
         self.joiner = ort.InferenceSession(
             args.joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_joiner_encoder_proj(self, args):
         self.joiner_encoder_proj = ort.InferenceSession(
             args.joiner_encoder_proj_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_joiner_decoder_proj(self, args):
         self.joiner_decoder_proj = ort.InferenceSession(
             args.joiner_decoder_proj_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
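This Model class keeps five sessions because, in this recipe, the joiner's encoder and decoder projections are exported as standalone ONNX models. A hedged sketch of how one of those projection sessions is invoked; the shape, dtype, and output unpacking are assumptions, only the session itself comes from the diff:

import numpy as np

# Illustrative input; the real encoder_out comes from the encoder session.
encoder_out = np.zeros((4, 512), dtype=np.float32)
(projected_encoder_out,) = joiner_encoder_proj.run(
    None,  # None fetches every declared output
    {joiner_encoder_proj.get_inputs()[0].name: encoder_out},
)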
@@ -307,6 +307,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -312,6 +312,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -150,12 +150,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -169,6 +171,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -78,6 +78,7 @@ def test_conv2d_subsampling():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

@@ -133,6 +134,7 @@ def test_rel_pos():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

@@ -220,6 +222,7 @@ def test_conformer_encoder_layer():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

@@ -304,6 +307,7 @@ def test_conformer_encoder():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

@@ -359,6 +363,7 @@ def test_conformer():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -404,6 +404,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -335,6 +335,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -138,6 +138,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )
         self.init_encoder_states()

@@ -185,6 +186,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -198,6 +200,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -329,6 +329,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -74,6 +74,7 @@ def test_conv2d_subsampling():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

@@ -128,6 +129,7 @@ def test_rel_pos():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

@@ -204,6 +206,7 @@ def test_zipformer_encoder_layer():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

@@ -284,6 +287,7 @@ def test_zipformer_encoder():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()

@@ -338,6 +342,7 @@ def test_zipformer():
     session = ort.InferenceSession(
         filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )

     input_nodes = session.get_inputs()
@@ -326,41 +326,49 @@ def main():
     encoder = ort.InferenceSession(
         args.encoder_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     decoder = ort.InferenceSession(
         args.decoder_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     joiner = ort.InferenceSession(
         args.joiner_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     joiner_encoder_proj = ort.InferenceSession(
         args.joiner_encoder_proj_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     joiner_decoder_proj = ort.InferenceSession(
         args.joiner_decoder_proj_model_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     lconv = ort.InferenceSession(
         args.lconv_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     frame_reducer = ort.InferenceSession(
         args.frame_reducer_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     ctc_output = ort.InferenceSession(
         args.ctc_output_filename,
         sess_options=session_opts,
+        providers=["CPUExecutionProvider"],
     )

     sp = spm.SentencePieceProcessor()
@@ -413,6 +413,7 @@ def export_decoder_model_onnx(
     context_size = decoder_model.decoder.context_size
     vocab_size = decoder_model.decoder.vocab_size
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -401,6 +401,7 @@ def export_decoder_model_onnx(
     context_size = decoder_model.decoder.context_size
     vocab_size = decoder_model.decoder.vocab_size
     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -130,6 +130,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )
         self.init_encoder_states()

@@ -229,6 +230,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -242,6 +244,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -865,7 +865,7 @@ class ZipformerEncoderLayer(nn.Module):
             return final_dropout_rate
         else:
             return initial_dropout_rate - (
-                initial_dropout_rate * final_dropout_rate
+                initial_dropout_rate - final_dropout_rate
             ) * (self.batch_count / warmup_period)

     def forward(
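The one-character fix above replaces initial_dropout_rate * final_dropout_rate with initial_dropout_rate - final_dropout_rate, turning the expression into a linear interpolation from the initial rate down to the final rate over warmup_period. With the surrounding defaults (assumed here: 0.2, 0.0, and 2000.0), the old product was always zero, so the schedule stayed pinned at the initial rate and never decayed. A small self-checking sketch of the corrected schedule:

def dynamic_dropout_rate(
    batch_count: float,
    warmup_period: float = 2000.0,      # assumed default
    initial_dropout_rate: float = 0.2,  # assumed default
    final_dropout_rate: float = 0.0,    # assumed default
) -> float:
    if batch_count > warmup_period:
        return final_dropout_rate
    # Linear interpolation, as in the fixed line above.
    return initial_dropout_rate - (
        initial_dropout_rate - final_dropout_rate
    ) * (batch_count / warmup_period)

assert dynamic_dropout_rate(0.0) == 0.2      # starts at the initial rate
assert dynamic_dropout_rate(1000.0) == 0.1   # halfway through warmup
assert dynamic_dropout_rate(2000.0) == 0.0   # reaches the final rate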
@@ -230,7 +230,7 @@ class Conformer(Transformer):
                 x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
             )  # (T, B, F)
         else:
-            x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
+            x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask)  # (T, B, F)

         if self.normalize_before:
             x = self.after_norm(x)
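The fix passes src_key_padding_mask rather than mask to the encoder in the else branch. In the PyTorch Transformer convention these are different objects: mask is the (T, T) attention mask, while src_key_padding_mask is the (N, T) per-utterance padding mask, so the old call handed the encoder a mask of the wrong shape and meaning. An illustrative sketch of the two conventions (values are placeholders):

import torch

T, N = 6, 2  # frames, batch size
attn_mask = torch.zeros(T, T, dtype=torch.bool)             # `mask`
src_key_padding_mask = torch.zeros(N, T, dtype=torch.bool)  # True marks padding
src_key_padding_mask[0, 4:] = True  # first utterance has only 4 valid frames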
@@ -506,6 +506,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -353,6 +353,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -146,6 +146,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )
         self.init_encoder_states()

@@ -236,6 +237,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -249,6 +251,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -151,12 +151,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -170,6 +172,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

@@ -315,6 +315,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,
@@ -258,6 +258,7 @@ def main():
     encoder_session = ort.InferenceSession(
         args.onnx_encoder_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )
     test_encoder(model, encoder_session)

@@ -265,6 +266,7 @@ def main():
     decoder_session = ort.InferenceSession(
         args.onnx_decoder_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )
     test_decoder(model, decoder_session)

@@ -272,14 +274,17 @@ def main():
     joiner_session = ort.InferenceSession(
         args.onnx_joiner_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )
     joiner_encoder_proj_session = ort.InferenceSession(
         args.onnx_joiner_encoder_proj_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )
     joiner_decoder_proj_session = ort.InferenceSession(
         args.onnx_joiner_decoder_proj_filename,
         sess_options=options,
+        providers=["CPUExecutionProvider"],
     )
     test_joiner(
         model,
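The three main() hunks above belong to a check script that runs the torch modules and their ONNX counterparts on the same inputs (test_encoder, test_decoder, test_joiner). A hedged sketch of the comparison idiom such a check typically reduces to; the torch-side call, the input y, and the tolerance are assumptions:

import torch

def check_decoder(model, decoder_session, y: torch.Tensor) -> None:
    torch_out = model.decoder(y)  # assumed torch-side call
    (onnx_out,) = decoder_session.run(
        None,  # None fetches every declared output
        {decoder_session.get_inputs()[0].name: y.numpy()},
    )
    assert torch.allclose(torch_out, torch.from_numpy(onnx_out), atol=1e-5)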
@@ -404,6 +404,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -335,6 +335,7 @@ def export_decoder_model_onnx(
     vocab_size = decoder_model.decoder.vocab_size

     y = torch.zeros(10, context_size, dtype=torch.int64)
+    decoder_model = torch.jit.script(decoder_model)
     torch.onnx.export(
         decoder_model,
         y,

@@ -139,6 +139,7 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )
         self.init_encoder_states()

@@ -186,6 +187,7 @@ class OnnxModel:
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -199,6 +201,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -158,12 +158,14 @@ class OnnxModel:
         self.encoder = ort.InferenceSession(
             encoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

     def init_decoder(self, decoder_model_filename: str):
         self.decoder = ort.InferenceSession(
             decoder_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         decoder_meta = self.decoder.get_modelmeta().custom_metadata_map

@@ -177,6 +179,7 @@ class OnnxModel:
         self.joiner = ort.InferenceSession(
             joiner_model_filename,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
@@ -54,6 +54,7 @@ class OnnxModel:
         self.model = ort.InferenceSession(
             nn_model,
             sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
         )

         meta = self.model.get_modelmeta().custom_metadata_map