mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 23:54:17 +00:00
Update export-onnx.py
This commit is contained in:
parent
e401a724ac
commit
fefae1f8df
@ -122,6 +122,7 @@ class OnnxModel(nn.Module):
|
|||||||
tokens_lens: torch.Tensor,
|
tokens_lens: torch.Tensor,
|
||||||
noise_scale: float = 0.667,
|
noise_scale: float = 0.667,
|
||||||
noise_scale_dur: float = 0.8,
|
noise_scale_dur: float = 0.8,
|
||||||
|
speaker: int = 20,
|
||||||
alpha: float = 1.0,
|
alpha: float = 1.0,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Please see the help information of VITS.inference_batch
|
"""Please see the help information of VITS.inference_batch
|
||||||
@ -135,6 +136,8 @@ class OnnxModel(nn.Module):
|
|||||||
Noise scale parameter for flow.
|
Noise scale parameter for flow.
|
||||||
noise_scale_dur (float):
|
noise_scale_dur (float):
|
||||||
Noise scale parameter for duration predictor.
|
Noise scale parameter for duration predictor.
|
||||||
|
speaker (int):
|
||||||
|
Speaker ID.
|
||||||
alpha (float):
|
alpha (float):
|
||||||
Alpha parameter to control the speed of generated speech.
|
Alpha parameter to control the speed of generated speech.
|
||||||
|
|
||||||
@ -147,6 +150,7 @@ class OnnxModel(nn.Module):
|
|||||||
text_lengths=tokens_lens,
|
text_lengths=tokens_lens,
|
||||||
noise_scale=noise_scale,
|
noise_scale=noise_scale,
|
||||||
noise_scale_dur=noise_scale_dur,
|
noise_scale_dur=noise_scale_dur,
|
||||||
|
sids=speaker,
|
||||||
alpha=alpha,
|
alpha=alpha,
|
||||||
)
|
)
|
||||||
return audio
|
return audio
|
||||||
@ -179,10 +183,11 @@ def export_model_onnx(
|
|||||||
noise_scale = torch.tensor([1], dtype=torch.float32)
|
noise_scale = torch.tensor([1], dtype=torch.float32)
|
||||||
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
||||||
alpha = torch.tensor([1], dtype=torch.float32)
|
alpha = torch.tensor([1], dtype=torch.float32)
|
||||||
|
speaker = torch.tensor([1], dtype=torch.int64)
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
model,
|
model,
|
||||||
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
|
(tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
|
||||||
model_filename,
|
model_filename,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
@ -191,6 +196,7 @@ def export_model_onnx(
|
|||||||
"tokens_lens",
|
"tokens_lens",
|
||||||
"noise_scale",
|
"noise_scale",
|
||||||
"noise_scale_dur",
|
"noise_scale_dur",
|
||||||
|
"speaker",
|
||||||
"alpha",
|
"alpha",
|
||||||
],
|
],
|
||||||
output_names=["audio"],
|
output_names=["audio"],
|
||||||
@ -198,6 +204,7 @@ def export_model_onnx(
|
|||||||
"tokens": {0: "N", 1: "T"},
|
"tokens": {0: "N", 1: "T"},
|
||||||
"tokens_lens": {0: "N"},
|
"tokens_lens": {0: "N"},
|
||||||
"audio": {0: "N", 1: "T"},
|
"audio": {0: "N", 1: "T"},
|
||||||
|
"speaker": {0: "N"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user