Update export-onnx.py

This commit is contained in:
jinzr 2023-12-05 09:24:40 +08:00
parent e401a724ac
commit fefae1f8df

View File

@ -122,6 +122,7 @@ class OnnxModel(nn.Module):
tokens_lens: torch.Tensor, tokens_lens: torch.Tensor,
noise_scale: float = 0.667, noise_scale: float = 0.667,
noise_scale_dur: float = 0.8, noise_scale_dur: float = 0.8,
speaker: int = 20,
alpha: float = 1.0, alpha: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of VITS.inference_batch """Please see the help information of VITS.inference_batch
@ -135,6 +136,8 @@ class OnnxModel(nn.Module):
Noise scale parameter for flow. Noise scale parameter for flow.
noise_scale_dur (float): noise_scale_dur (float):
Noise scale parameter for duration predictor. Noise scale parameter for duration predictor.
speaker (int):
Speaker ID.
alpha (float): alpha (float):
Alpha parameter to control the speed of generated speech. Alpha parameter to control the speed of generated speech.
@ -147,6 +150,7 @@ class OnnxModel(nn.Module):
text_lengths=tokens_lens, text_lengths=tokens_lens,
noise_scale=noise_scale, noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur, noise_scale_dur=noise_scale_dur,
sids=speaker,
alpha=alpha, alpha=alpha,
) )
return audio return audio
@ -179,10 +183,11 @@ def export_model_onnx(
noise_scale = torch.tensor([1], dtype=torch.float32) noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32) noise_scale_dur = torch.tensor([1], dtype=torch.float32)
alpha = torch.tensor([1], dtype=torch.float32) alpha = torch.tensor([1], dtype=torch.float32)
speaker = torch.tensor([1], dtype=torch.int64)
torch.onnx.export( torch.onnx.export(
model, model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
model_filename, model_filename,
verbose=False, verbose=False,
opset_version=opset_version, opset_version=opset_version,
@ -191,6 +196,7 @@ def export_model_onnx(
"tokens_lens", "tokens_lens",
"noise_scale", "noise_scale",
"noise_scale_dur", "noise_scale_dur",
"speaker",
"alpha", "alpha",
], ],
output_names=["audio"], output_names=["audio"],
@ -198,6 +204,7 @@ def export_model_onnx(
"tokens": {0: "N", 1: "T"}, "tokens": {0: "N", 1: "T"},
"tokens_lens": {0: "N"}, "tokens_lens": {0: "N"},
"audio": {0: "N", 1: "T"}, "audio": {0: "N", 1: "T"},
"speaker": {0: "N"},
}, },
) )