minor adjustments to the VITS recipes for onnx runtime (#1405)

This commit is contained in:
zr_jin 2023-12-08 06:32:40 +08:00 committed by GitHub
parent b87ed26c09
commit bda72f86ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 9 deletions

View File

@ -176,7 +176,7 @@ def export_model_onnx(
torch.onnx.export( torch.onnx.export(
model, model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
model_filename, model_filename,
verbose=False, verbose=False,
opset_version=opset_version, opset_version=opset_version,
@ -184,8 +184,8 @@ def export_model_onnx(
"tokens", "tokens",
"tokens_lens", "tokens_lens",
"noise_scale", "noise_scale",
"noise_scale_dur",
"alpha", "alpha",
"noise_scale_dur",
], ],
output_names=["audio"], output_names=["audio"],
dynamic_axes={ dynamic_axes={

View File

@ -92,8 +92,8 @@ class OnnxModel:
self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(), self.model.get_inputs()[3].name: alpha.numpy(),
self.model.get_inputs()[4].name: alpha.numpy(), self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
}, },
)[0] )[0]
return torch.from_numpy(out) return torch.from_numpy(out)

View File

@ -187,7 +187,7 @@ def export_model_onnx(
torch.onnx.export( torch.onnx.export(
model, model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha), (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker),
model_filename, model_filename,
verbose=False, verbose=False,
opset_version=opset_version, opset_version=opset_version,
@ -195,9 +195,9 @@ def export_model_onnx(
"tokens", "tokens",
"tokens_lens", "tokens_lens",
"noise_scale", "noise_scale",
"alpha",
"noise_scale_dur", "noise_scale_dur",
"speaker", "speaker",
"alpha",
], ],
output_names=["audio"], output_names=["audio"],
dynamic_axes={ dynamic_axes={

View File

@ -101,9 +101,9 @@ class OnnxModel:
self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(), self.model.get_inputs()[3].name: alpha.numpy(),
self.model.get_inputs()[4].name: speaker.numpy(), self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
self.model.get_inputs()[5].name: alpha.numpy(), self.model.get_inputs()[5].name: speaker.numpy(),
}, },
)[0] )[0]
return torch.from_numpy(out) return torch.from_numpy(out)