mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 01:22:22 +00:00
minor adjustments to the VITS recipes for onnx runtime (#1405)
This commit is contained in:
parent
b87ed26c09
commit
bda72f86ff
@ -176,7 +176,7 @@ def export_model_onnx(
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
|
||||
(tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
|
||||
model_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
@ -184,8 +184,8 @@ def export_model_onnx(
|
||||
"tokens",
|
||||
"tokens_lens",
|
||||
"noise_scale",
|
||||
"noise_scale_dur",
|
||||
"alpha",
|
||||
"noise_scale_dur",
|
||||
],
|
||||
output_names=["audio"],
|
||||
dynamic_axes={
|
||||
|
@ -92,8 +92,8 @@ class OnnxModel:
|
||||
self.model.get_inputs()[0].name: tokens.numpy(),
|
||||
self.model.get_inputs()[1].name: tokens_lens.numpy(),
|
||||
self.model.get_inputs()[2].name: noise_scale.numpy(),
|
||||
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
|
||||
self.model.get_inputs()[4].name: alpha.numpy(),
|
||||
self.model.get_inputs()[3].name: alpha.numpy(),
|
||||
self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
|
||||
},
|
||||
)[0]
|
||||
return torch.from_numpy(out)
|
||||
|
@ -187,7 +187,7 @@ def export_model_onnx(
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
|
||||
(tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker),
|
||||
model_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
@ -195,9 +195,9 @@ def export_model_onnx(
|
||||
"tokens",
|
||||
"tokens_lens",
|
||||
"noise_scale",
|
||||
"alpha",
|
||||
"noise_scale_dur",
|
||||
"speaker",
|
||||
"alpha",
|
||||
],
|
||||
output_names=["audio"],
|
||||
dynamic_axes={
|
||||
|
@ -101,9 +101,9 @@ class OnnxModel:
|
||||
self.model.get_inputs()[0].name: tokens.numpy(),
|
||||
self.model.get_inputs()[1].name: tokens_lens.numpy(),
|
||||
self.model.get_inputs()[2].name: noise_scale.numpy(),
|
||||
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
|
||||
self.model.get_inputs()[4].name: speaker.numpy(),
|
||||
self.model.get_inputs()[5].name: alpha.numpy(),
|
||||
self.model.get_inputs()[3].name: alpha.numpy(),
|
||||
self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
|
||||
self.model.get_inputs()[5].name: speaker.numpy(),
|
||||
},
|
||||
)[0]
|
||||
return torch.from_numpy(out)
|
||||
|
Loading…
x
Reference in New Issue
Block a user