mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
Update export-onnx.py for variable token counts
This commit is contained in:
parent
e5bb1ae86c
commit
278d637013
@ -149,6 +149,7 @@ class OnnxModel(nn.Module):
|
|||||||
def export_model_onnx(
|
def export_model_onnx(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
model_filename: str,
|
model_filename: str,
|
||||||
|
vocab_size: int,
|
||||||
opset_version: int = 11,
|
opset_version: int = 11,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Export the given generator model to ONNX format.
|
"""Export the given generator model to ONNX format.
|
||||||
@ -168,7 +169,7 @@ def export_model_onnx(
|
|||||||
opset_version:
|
opset_version:
|
||||||
The opset version to use.
|
The opset version to use.
|
||||||
"""
|
"""
|
||||||
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
|
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
|
||||||
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
|
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
|
||||||
noise_scale = torch.tensor([1], dtype=torch.float32)
|
noise_scale = torch.tensor([1], dtype=torch.float32)
|
||||||
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
||||||
@ -244,6 +245,7 @@ def main():
|
|||||||
export_model_onnx(
|
export_model_onnx(
|
||||||
model,
|
model,
|
||||||
model_filename,
|
model_filename,
|
||||||
|
params.vocab_size,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
)
|
)
|
||||||
logging.info(f"Exported generator to {model_filename}")
|
logging.info(f"Exported generator to {model_filename}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user