mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
allow export of onnx-streaming-models with other than 80dim input features (#1556)
This commit is contained in:
parent
eec12f053d
commit
4917ac8bab
@ -333,6 +333,7 @@ def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
feature_dim: int = 80,
|
||||
) -> None:
|
||||
encoder_model.encoder.__class__.forward = (
|
||||
encoder_model.encoder.__class__.streaming_forward
|
||||
@ -343,7 +344,7 @@ def export_encoder_model_onnx(
|
||||
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||
T = decode_chunk_len + encoder_model.pad_length
|
||||
|
||||
x = torch.rand(1, T, 80, dtype=torch.float32)
|
||||
x = torch.rand(1, T, feature_dim, dtype=torch.float32)
|
||||
init_state = encoder_model.get_init_states()
|
||||
num_encoders = len(encoder_model.encoder.encoder_dim)
|
||||
logging.info(f"num_encoders: {num_encoders}")
|
||||
@ -724,6 +725,7 @@ def main():
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
feature_dim=params.feature_dim,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user