stream onnx export suc

commit 0b63b87092 (parent 1378f833bd)
@@ -347,13 +347,57 @@ def export_encoder_model_onnx(
     outputs = {}
     output_names = ["encoder_out"]
 
-    def build_inputs_outputs(tensors, name, N):
-        for i, s in enumerate(tensors):
-            logging.info(f"{name}_{i}.shape: {s.shape}")
-            inputs[f"{name}_{i}"] = {N: "N"}
-            outputs[f"new_{name}_{i}"] = {N: "N"}
-            input_names.append(f"{name}_{i}")
-            output_names.append(f"new_{name}_{i}")
+    def build_inputs_outputs(tensors, i):
+        assert len(tensors) == 6, len(tensors)
+
+        # (downsample_left, batch_size, key_dim)
+        name = f'cached_key_{i}'
+        logging.info(f"{name}.shape: {tensors[0].shape}")
+        inputs[f"{name}"] = {1: "N"}
+        outputs[f"new_{name}"] = {1: "N"}
+        input_names.append(f"{name}")
+        output_names.append(f"new_{name}")
+
+        # (1, batch_size, downsample_left, nonlin_attn_head_dim)
+        name = f'cached_nonlin_attn_{i}'
+        logging.info(f"{name}.shape: {tensors[1].shape}")
+        inputs[f"{name}"] = {1: "N"}
+        outputs[f"new_{name}"] = {1: "N"}
+        input_names.append(f"{name}")
+        output_names.append(f"new_{name}")
+
+        # (downsample_left, batch_size, value_dim)
+        name = f'cached_val1_{i}'
+        logging.info(f"{name}.shape: {tensors[2].shape}")
+        inputs[f"{name}"] = {1: "N"}
+        outputs[f"new_{name}"] = {1: "N"}
+        input_names.append(f"{name}")
+        output_names.append(f"new_{name}")
+
+        # (downsample_left, batch_size, value_dim)
+        name = f'cached_val2_{i}'
+        logging.info(f"{name}.shape: {tensors[3].shape}")
+        inputs[f"{name}"] = {1: "N"}
+        outputs[f"new_{name}"] = {1: "N"}
+        input_names.append(f"{name}")
+        output_names.append(f"new_{name}")
+
+        # (batch_size, embed_dim, conv_left_pad)
+        name = f'cached_conv1_{i}'
+        logging.info(f"{name}.shape: {tensors[4].shape}")
+        inputs[f"{name}"] = {0: "N"}
+        outputs[f"new_{name}"] = {0: "N"}
+        input_names.append(f"{name}")
+        output_names.append(f"new_{name}")
+
+        # (batch_size, embed_dim, conv_left_pad)
+        name = f'cached_conv2_{i}'
+        logging.info(f"{name}.shape: {tensors[5].shape}")
+        inputs[f"{name}"] = {0: "N"}
+        outputs[f"new_{name}"] = {0: "N"}
+        input_names.append(f"{name}")
+        output_names.append(f"new_{name}")
 
     num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
     encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
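Note: for each layer index i, the rewritten build_inputs_outputs registers six cache tensors and marks their batch axis dynamic, axis 1 for the attention caches and axis 0 for the convolution caches. A minimal sketch of the entries it produces for layer 0, assuming the surrounding script's inputs/outputs dicts and input_names/output_names lists (shapes are copied from the comments in the diff):

    # Dynamic-axes entries registered for i = 0.
    inputs = {
        "cached_key_0": {1: "N"},          # (downsample_left, batch, key_dim)
        "cached_nonlin_attn_0": {1: "N"},  # (1, batch, downsample_left, nonlin_attn_head_dim)
        "cached_val1_0": {1: "N"},         # (downsample_left, batch, value_dim)
        "cached_val2_0": {1: "N"},         # (downsample_left, batch, value_dim)
        "cached_conv1_0": {0: "N"},        # (batch, embed_dim, conv_left_pad)
        "cached_conv2_0": {0: "N"},        # (batch, embed_dim, conv_left_pad)
    }
    outputs = {f"new_{k}": v for k, v in inputs.items()}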
@@ -376,30 +420,9 @@ def export_encoder_model_onnx(
     }
     logging.info(f"meta_data: {meta_data}")
 
-    # (num_encoder_layers, left_context_len, 1, key_dim)
-    cached_key = init_state[0:-2:6]
+    for i in range(len(init_state[:-2]) // 6):
+        build_inputs_outputs(init_state[i*6:(i+1)*6], i)
 
-    # (num_encoder_layers, 1, 1, left_context_len, nonlin_attn_head_dim)
-    cached_nonlin_attn = init_state[1:-2:6]
-
-    # (num_encoder_layers, left_context_len, 1, attention_dim//2)
-    cached_val1 = init_state[2:-2:6]
-
-    # (num_encoder_layers, left_context_len, 1, attention_dim//2)
-    cached_val2 = init_state[3:-2:6]
-
-    # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
-    cached_conv1 = init_state[4:-2:6]
-
-    # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
-    cached_conv2 = init_state[5:-2:6]
-
-    build_inputs_outputs(cached_key, "cached_key", 1)
-    build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 1)
-    build_inputs_outputs(cached_val1, "cached_val1", 1)
-    build_inputs_outputs(cached_val2, "cached_val2", 1)
-    build_inputs_outputs(cached_conv1, "cached_conv1", 0)
-    build_inputs_outputs(cached_conv2, "cached_conv2", 0)
 
     # (batch_size, channels, left_pad, freq)
     embed_states = init_state[-2]
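The caller no longer groups init_state by cache type with strided slices (init_state[0:-2:6], init_state[1:-2:6], and so on); it now walks the flat state list six tensors at a time, one group per encoder layer. A sketch of the layout this implies, assuming only what the slicing in this hunk shows (embed_states at index -2; the final element is not shown in this hunk and is left unnamed):

    # Assumed flat layout of init_state: six caches per layer, then two
    # trailing tensors.
    # [key_0, nonlin_attn_0, val1_0, val2_0, conv1_0, conv2_0,
    #  key_1, nonlin_attn_1, val1_1, val2_1, conv1_1, conv2_1,
    #  ...,
    #  embed_states, <last element>]
    num_layers = len(init_state[:-2]) // 6
    for i in range(num_layers):
        layer_caches = init_state[i * 6 : (i + 1) * 6]
        build_inputs_outputs(layer_caches, i)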
@@ -424,13 +447,11 @@ def export_encoder_model_onnx(
     logging.info(input_names)
     logging.info(output_names)
 
-    encoder_model = torch.jit.trace(encoder_model, (x, init_state))
-
     torch.onnx.export(
         encoder_model,
         (x, init_state),
         encoder_filename,
-        verbose=True,
+        verbose=False,
         opset_version=opset_version,
         input_names=input_names,
         output_names=output_names,
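This hunk drops the separate torch.jit.trace call and silences verbose export logging; torch.onnx.export traces the module itself with the example inputs (x, init_state), so the pre-tracing step appears redundant. A hedged sanity check of the exported graph with onnxruntime (the file name below is a placeholder; the script writes to encoder_filename):

    import onnxruntime as ort

    sess = ort.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])
    # Expect the cache inputs registered above and their "new_"-prefixed
    # counterparts; output_names starts with "encoder_out".
    print([i.name for i in sess.get_inputs()])
    print([o.name for o in sess.get_outputs()])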
@@ -653,7 +674,7 @@ def main():
     model.to("cpu")
     model.eval()
 
-    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
+    convert_scaled_to_non_scaled(model, inplace=True)
 
     encoder = OnnxEncoder(
         encoder=model.encoder,
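With the per-layer names in place, a streaming run feeds each new chunk together with the current caches and swaps the new_* outputs back in for the next step. A rough sketch under the naming scheme built above; the feature input name, chunk iterator, and initial cache values are assumptions for illustration, not part of this diff:

    # state_names: the cache input names registered by build_inputs_outputs,
    # in order. output_names is "encoder_out" followed by the matching
    # "new_"-prefixed names, so outs[1:] lines up with state_names.
    states = dict(zip(state_names, initial_state_values))  # hypothetical init values
    for chunk in feature_chunks:                           # hypothetical chunk iterator
        outs = sess.run(None, {feature_input_name: chunk, **states})
        encoder_out, new_states = outs[0], outs[1:]
        states = dict(zip(state_names, new_states))        # feed caches back in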