Update export-onnx.py

This commit is contained in:
jinzr 2023-12-05 09:24:40 +08:00
parent e401a724ac
commit fefae1f8df

View File

@ -122,6 +122,7 @@ class OnnxModel(nn.Module):
tokens_lens: torch.Tensor,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
speaker: int = 20,
alpha: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of VITS.inference_batch
@ -135,6 +136,8 @@ class OnnxModel(nn.Module):
Noise scale parameter for flow.
noise_scale_dur (float):
Noise scale parameter for duration predictor.
speaker (int):
Speaker ID.
alpha (float):
Alpha parameter to control the speed of generated speech.
@ -147,6 +150,7 @@ class OnnxModel(nn.Module):
text_lengths=tokens_lens,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
sids=speaker,
alpha=alpha,
)
return audio
@ -179,10 +183,11 @@ def export_model_onnx(
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
alpha = torch.tensor([1], dtype=torch.float32)
speaker = torch.tensor([1], dtype=torch.int64)
torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
(tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
model_filename,
verbose=False,
opset_version=opset_version,
@ -191,6 +196,7 @@ def export_model_onnx(
"tokens_lens",
"noise_scale",
"noise_scale_dur",
"speaker",
"alpha",
],
output_names=["audio"],
@ -198,6 +204,7 @@ def export_model_onnx(
"tokens": {0: "N", 1: "T"},
"tokens_lens": {0: "N"},
"audio": {0: "N", 1: "T"},
"speaker": {0: "N"},
},
)