From 278d6370133088a6acb75a15eb7da5073e53a3f1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ali=20Haznedaro=C4=9Flu?= <53865510+ahazned@users.noreply.github.com>
Date: Mon, 25 Dec 2023 13:48:03 +0300
Subject: [PATCH] Update export-onnx.py for variable token counts

---
 egs/ljspeech/TTS/vits/export-onnx.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
index 36a9de27f..ecf5341e3 100755
--- a/egs/ljspeech/TTS/vits/export-onnx.py
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -149,6 +149,7 @@ class OnnxModel(nn.Module):
 def export_model_onnx(
     model: nn.Module,
     model_filename: str,
+    vocab_size: int,
     opset_version: int = 11,
 ) -> None:
     """Export the given generator model to ONNX format.
@@ -168,7 +169,7 @@ def export_model_onnx(
       opset_version:
         The opset version to use.
     """
-    tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+    tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
     tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
     noise_scale = torch.tensor([1], dtype=torch.float32)
     noise_scale_dur = torch.tensor([1], dtype=torch.float32)
@@ -244,6 +245,7 @@ def main():
     export_model_onnx(
         model,
         model_filename,
+        params.vocab_size,
         opset_version=opset_version,
     )
     logging.info(f"Exported generator to {model_filename}")
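
Note (illustration only, not part of the patch): the previous hard-coded high=79
presumably matched the LJSpeech token table, so the random dummy token IDs could
exceed the embedding table for recipes with a different token count; drawing them
from [0, vocab_size) keeps the traced export valid for any token set. The sketch
below shows the idea with an invented TinyGenerator stand-in. Only the
torch.randint / torch.onnx.export usage mirrors the patched export_model_onnx;
the module, its shapes, and the function names are assumptions for the example.

import torch
import torch.nn as nn


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for the wrapped VITS generator: embeds token
    IDs and projects them to a fake waveform, just enough to trace."""

    def __init__(self, vocab_size: int, embed_dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.proj = nn.Linear(embed_dim, 1)

    def forward(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
        # (N, T) -> (N, T, C); any token ID >= vocab_size would fail here,
        # which is what a hard-coded high=79 could trigger for other recipes.
        emb = self.embed(tokens)
        # Zero out positions beyond tokens_lens, mask shape (N, T, 1).
        mask = (
            torch.arange(tokens.shape[1]).unsqueeze(0) < tokens_lens.unsqueeze(1)
        ).unsqueeze(-1)
        return self.proj(emb * mask.to(emb.dtype)).squeeze(-1)  # (N, T)


def export_tiny_generator(vocab_size: int, filename: str = "tiny-generator.onnx") -> None:
    """Export the stand-in model, drawing dummy token IDs from
    [0, vocab_size) exactly as the patched export_model_onnx does."""
    model = TinyGenerator(vocab_size)
    model.eval()

    tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)

    torch.onnx.export(
        model,
        (tokens, tokens_lens),
        filename,
        opset_version=11,
        input_names=["tokens", "tokens_lens"],
        output_names=["audio"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
            "tokens_lens": {0: "N"},
            "audio": {0: "N", 1: "T"},
        },
    )


if __name__ == "__main__":
    # A token table larger than 79 symbols, e.g. for a non-LJSpeech recipe.
    export_tiny_generator(vocab_size=256)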