Update export-onnx.py for variable token counts

This commit is contained in:
Ali Haznedaroğlu 2023-12-25 13:48:03 +03:00 committed by GitHub
parent e5bb1ae86c
commit 278d637013
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -149,6 +149,7 @@ class OnnxModel(nn.Module):
def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
@ -168,7 +169,7 @@ def export_model_onnx(
opset_version:
The opset version to use.
"""
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
@ -244,6 +245,7 @@ def main():
export_model_onnx(
model,
model_filename,
params.vocab_size,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")