mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Update export-onnx.py
This commit is contained in:
parent
e401a724ac
commit
fefae1f8df
@ -122,6 +122,7 @@ class OnnxModel(nn.Module):
|
||||
tokens_lens: torch.Tensor,
|
||||
noise_scale: float = 0.667,
|
||||
noise_scale_dur: float = 0.8,
|
||||
speaker: int = 20,
|
||||
alpha: float = 1.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Please see the help information of VITS.inference_batch
|
||||
@ -135,6 +136,8 @@ class OnnxModel(nn.Module):
|
||||
Noise scale parameter for flow.
|
||||
noise_scale_dur (float):
|
||||
Noise scale parameter for duration predictor.
|
||||
speaker (int):
|
||||
Speaker ID.
|
||||
alpha (float):
|
||||
Alpha parameter to control the speed of generated speech.
|
||||
|
||||
@ -147,6 +150,7 @@ class OnnxModel(nn.Module):
|
||||
text_lengths=tokens_lens,
|
||||
noise_scale=noise_scale,
|
||||
noise_scale_dur=noise_scale_dur,
|
||||
sids=speaker,
|
||||
alpha=alpha,
|
||||
)
|
||||
return audio
|
||||
@ -179,10 +183,11 @@ def export_model_onnx(
|
||||
noise_scale = torch.tensor([1], dtype=torch.float32)
|
||||
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
||||
alpha = torch.tensor([1], dtype=torch.float32)
|
||||
speaker = torch.tensor([1], dtype=torch.int64)
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
|
||||
(tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
|
||||
model_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
@ -191,6 +196,7 @@ def export_model_onnx(
|
||||
"tokens_lens",
|
||||
"noise_scale",
|
||||
"noise_scale_dur",
|
||||
"speaker",
|
||||
"alpha",
|
||||
],
|
||||
output_names=["audio"],
|
||||
@ -198,6 +204,7 @@ def export_model_onnx(
|
||||
"tokens": {0: "N", 1: "T"},
|
||||
"tokens_lens": {0: "N"},
|
||||
"audio": {0: "N", 1: "T"},
|
||||
"speaker": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user