stream onnx export suc

This commit is contained in:
danqing fu 2023-06-08 13:33:38 +08:00
parent 1378f833bd
commit 0b63b87092

View File

@ -347,13 +347,57 @@ def export_encoder_model_onnx(
outputs = {}
output_names = ["encoder_out"]
def build_inputs_outputs(tensors, name, N):
for i, s in enumerate(tensors):
logging.info(f"{name}_{i}.shape: {s.shape}")
inputs[f"{name}_{i}"] = {N: "N"}
outputs[f"new_{name}_{i}"] = {N: "N"}
input_names.append(f"{name}_{i}")
output_names.append(f"new_{name}_{i}")
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f'cached_key_{i}'
logging.info(f"{name}.shape: {tensors[0].shape}")
inputs[f"{name}"] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(f"{name}")
output_names.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f'cached_nonlin_attn_{i}'
logging.info(f"{name}.shape: {tensors[1].shape}")
inputs[f"{name}"] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(f"{name}")
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f'cached_val1_{i}'
logging.info(f"{name}.shape: {tensors[2].shape}")
inputs[f"{name}"] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(f"{name}")
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f'cached_val2_{i}'
logging.info(f"{name}.shape: {tensors[3].shape}")
inputs[f"{name}"] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(f"{name}")
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f'cached_conv1_{i}'
logging.info(f"{name}.shape: {tensors[4].shape}")
inputs[f"{name}"] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(f"{name}")
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f'cached_conv2_{i}'
logging.info(f"{name}.shape: {tensors[5].shape}")
inputs[f"{name}"] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(f"{name}")
output_names.append(f"new_{name}")
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
@ -376,30 +420,9 @@ def export_encoder_model_onnx(
}
logging.info(f"meta_data: {meta_data}")
# (num_encoder_layers, left_context_len, 1, key_dim)
cached_key = init_state[0:-2:6]
for i in range(len(init_state[:-2]) // 6):
build_inputs_outputs(init_state[i*6:(i+1)*6], i)
# (num_encoder_layers, 1, 1, left_context_len, nonlin_attn_head_dim)
cached_nonlin_attn = init_state[1:-2:6]
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
cached_val1 = init_state[2:-2:6]
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
cached_val2 = init_state[3:-2:6]
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
cached_conv1 = init_state[4:-2:6]
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
cached_conv2 = init_state[5:-2:6]
build_inputs_outputs(cached_key, "cached_key", 1)
build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 1)
build_inputs_outputs(cached_val1, "cached_val1", 1)
build_inputs_outputs(cached_val2, "cached_val2", 1)
build_inputs_outputs(cached_conv1, "cached_conv1", 0)
build_inputs_outputs(cached_conv2, "cached_conv2", 0)
# (batch_size, channels, left_pad, freq)
embed_states = init_state[-2]
@ -424,13 +447,11 @@ def export_encoder_model_onnx(
logging.info(input_names)
logging.info(output_names)
encoder_model = torch.jit.trace(encoder_model, (x, init_state))
torch.onnx.export(
encoder_model,
(x, init_state),
encoder_filename,
verbose=True,
verbose=False,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
@ -653,7 +674,7 @@ def main():
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,