mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Update TTS export-onnx.py scripts for handling variable token counts (#1430)
This commit is contained in:
parent
c855a58cfd
commit
ddd7131317
@ -149,6 +149,7 @@ class OnnxModel(nn.Module):
|
||||
def export_model_onnx(
|
||||
model: nn.Module,
|
||||
model_filename: str,
|
||||
vocab_size: int,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given generator model to ONNX format.
|
||||
@ -165,10 +166,12 @@ def export_model_onnx(
|
||||
The VITS generator.
|
||||
model_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
vocab_size:
|
||||
Number of tokens used in training.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
|
||||
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
|
||||
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
|
||||
noise_scale = torch.tensor([1], dtype=torch.float32)
|
||||
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
||||
@ -244,6 +247,7 @@ def main():
|
||||
export_model_onnx(
|
||||
model,
|
||||
model_filename,
|
||||
params.vocab_size,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported generator to {model_filename}")
|
||||
|
@ -159,6 +159,7 @@ class OnnxModel(nn.Module):
|
||||
def export_model_onnx(
|
||||
model: nn.Module,
|
||||
model_filename: str,
|
||||
vocab_size: int,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given generator model to ONNX format.
|
||||
@ -175,10 +176,12 @@ def export_model_onnx(
|
||||
The VITS generator.
|
||||
model_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
vocab_size:
|
||||
Number of tokens used in training.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
|
||||
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
|
||||
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
|
||||
noise_scale = torch.tensor([1], dtype=torch.float32)
|
||||
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
||||
@ -261,6 +264,7 @@ def main():
|
||||
export_model_onnx(
|
||||
model,
|
||||
model_filename,
|
||||
params.vocab_size,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported generator to {model_filename}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user