allow export of onnx-streaming-models with other than 80dim input features (#1556)

This commit is contained in:
Karel Vesely 2024-03-18 11:43:29 +01:00 committed by GitHub
parent eec12f053d
commit 4917ac8bab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -333,6 +333,7 @@ def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
feature_dim: int = 80,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
@ -343,7 +344,7 @@ def export_encoder_model_onnx(
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length
x = torch.rand(1, T, 80, dtype=torch.float32)
x = torch.rand(1, T, feature_dim, dtype=torch.float32)
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
@ -724,6 +725,7 @@ def main():
encoder,
encoder_filename,
opset_version=opset_version,
feature_dim=params.feature_dim,
)
logging.info(f"Exported encoder to {encoder_filename}")