mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
Update export-onnx.py for variable token counts
This commit is contained in:
parent
e5bb1ae86c
commit
278d637013
@ -149,6 +149,7 @@ class OnnxModel(nn.Module):
|
||||
def export_model_onnx(
|
||||
model: nn.Module,
|
||||
model_filename: str,
|
||||
vocab_size: int,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given generator model to ONNX format.
|
||||
@ -168,7 +169,7 @@ def export_model_onnx(
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
|
||||
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
|
||||
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
|
||||
noise_scale = torch.tensor([1], dtype=torch.float32)
|
||||
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
||||
@ -244,6 +245,7 @@ def main():
|
||||
export_model_onnx(
|
||||
model,
|
||||
model_filename,
|
||||
params.vocab_size,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
logging.info(f"Exported generator to {model_filename}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user