Update TTS export-onnx.py scripts for handling variable token counts (#1430)

This commit is contained in:
Ali Haznedaroğlu 2023-12-25 14:44:07 +03:00 committed by GitHub
parent c855a58cfd
commit ddd7131317
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 2 deletions

View File

@ -149,6 +149,7 @@ class OnnxModel(nn.Module):
def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
@ -165,10 +166,12 @@ def export_model_onnx(
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
vocab_size:
Number of tokens used in training.
opset_version:
The opset version to use.
"""
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
@ -244,6 +247,7 @@ def main():
export_model_onnx(
model,
model_filename,
params.vocab_size,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")

View File

@ -159,6 +159,7 @@ class OnnxModel(nn.Module):
def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
@ -175,10 +176,12 @@ def export_model_onnx(
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
vocab_size:
Number of tokens used in training.
opset_version:
The opset version to use.
"""
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
@ -261,6 +264,7 @@ def main():
export_model_onnx(
model,
model_filename,
params.vocab_size,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")