minor adjustments to the VITS recipes for onnx runtime (#1405)

This commit is contained in:
zr_jin 2023-12-08 06:32:40 +08:00 committed by GitHub
parent b87ed26c09
commit bda72f86ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 9 deletions

View File

@ -176,7 +176,7 @@ def export_model_onnx(
torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
(tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
model_filename,
verbose=False,
opset_version=opset_version,
@ -184,8 +184,8 @@ def export_model_onnx(
"tokens",
"tokens_lens",
"noise_scale",
"noise_scale_dur",
"alpha",
"noise_scale_dur",
],
output_names=["audio"],
dynamic_axes={

View File

@ -92,8 +92,8 @@ class OnnxModel:
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
self.model.get_inputs()[4].name: alpha.numpy(),
self.model.get_inputs()[3].name: alpha.numpy(),
self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
},
)[0]
return torch.from_numpy(out)

View File

@ -187,7 +187,7 @@ def export_model_onnx(
torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
(tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker),
model_filename,
verbose=False,
opset_version=opset_version,
@ -195,9 +195,9 @@ def export_model_onnx(
"tokens",
"tokens_lens",
"noise_scale",
"alpha",
"noise_scale_dur",
"speaker",
"alpha",
],
output_names=["audio"],
dynamic_axes={

View File

@ -101,9 +101,9 @@ class OnnxModel:
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
self.model.get_inputs()[4].name: speaker.numpy(),
self.model.get_inputs()[5].name: alpha.numpy(),
self.model.get_inputs()[3].name: alpha.numpy(),
self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
self.model.get_inputs()[5].name: speaker.numpy(),
},
)[0]
return torch.from_numpy(out)