mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
allow export of onnx-streaming-models with other than 80dim input features (#1556)
This commit is contained in:
parent
eec12f053d
commit
4917ac8bab
@ -333,6 +333,7 @@ def export_encoder_model_onnx(
|
|||||||
encoder_model: OnnxEncoder,
|
encoder_model: OnnxEncoder,
|
||||||
encoder_filename: str,
|
encoder_filename: str,
|
||||||
opset_version: int = 11,
|
opset_version: int = 11,
|
||||||
|
feature_dim: int = 80,
|
||||||
) -> None:
|
) -> None:
|
||||||
encoder_model.encoder.__class__.forward = (
|
encoder_model.encoder.__class__.forward = (
|
||||||
encoder_model.encoder.__class__.streaming_forward
|
encoder_model.encoder.__class__.streaming_forward
|
||||||
@ -343,7 +344,7 @@ def export_encoder_model_onnx(
|
|||||||
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||||
T = decode_chunk_len + encoder_model.pad_length
|
T = decode_chunk_len + encoder_model.pad_length
|
||||||
|
|
||||||
x = torch.rand(1, T, 80, dtype=torch.float32)
|
x = torch.rand(1, T, feature_dim, dtype=torch.float32)
|
||||||
init_state = encoder_model.get_init_states()
|
init_state = encoder_model.get_init_states()
|
||||||
num_encoders = len(encoder_model.encoder.encoder_dim)
|
num_encoders = len(encoder_model.encoder.encoder_dim)
|
||||||
logging.info(f"num_encoders: {num_encoders}")
|
logging.info(f"num_encoders: {num_encoders}")
|
||||||
@ -724,6 +725,7 @@ def main():
|
|||||||
encoder,
|
encoder,
|
||||||
encoder_filename,
|
encoder_filename,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
|
feature_dim=params.feature_dim,
|
||||||
)
|
)
|
||||||
logging.info(f"Exported encoder to {encoder_filename}")
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user