mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
stream onnx export suc
This commit is contained in:
parent
1378f833bd
commit
0b63b87092
@ -269,7 +269,7 @@ class OnnxEncoder(nn.Module):
|
|||||||
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||||
states.append(processed_lens)
|
states.append(processed_lens)
|
||||||
|
|
||||||
return states
|
return states
|
||||||
|
|
||||||
|
|
||||||
class OnnxDecoder(nn.Module):
|
class OnnxDecoder(nn.Module):
|
||||||
@ -347,13 +347,57 @@ def export_encoder_model_onnx(
|
|||||||
|
|
||||||
outputs = {}
|
outputs = {}
|
||||||
output_names = ["encoder_out"]
|
output_names = ["encoder_out"]
|
||||||
def build_inputs_outputs(tensors, name, N):
|
def build_inputs_outputs(tensors, i):
|
||||||
for i, s in enumerate(tensors):
|
assert len(tensors) == 6, len(tensors)
|
||||||
logging.info(f"{name}_{i}.shape: {s.shape}")
|
|
||||||
inputs[f"{name}_{i}"] = {N: "N"}
|
# (downsample_left, batch_size, key_dim)
|
||||||
outputs[f"new_{name}_{i}"] = {N: "N"}
|
name = f'cached_key_{i}'
|
||||||
input_names.append(f"{name}_{i}")
|
logging.info(f"{name}.shape: {tensors[0].shape}")
|
||||||
output_names.append(f"new_{name}_{i}")
|
inputs[f"{name}"] = {1: "N"}
|
||||||
|
outputs[f"new_{name}"] = {1: "N"}
|
||||||
|
input_names.append(f"{name}")
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
|
||||||
|
name = f'cached_nonlin_attn_{i}'
|
||||||
|
logging.info(f"{name}.shape: {tensors[1].shape}")
|
||||||
|
inputs[f"{name}"] = {1: "N"}
|
||||||
|
outputs[f"new_{name}"] = {1: "N"}
|
||||||
|
input_names.append(f"{name}")
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (downsample_left, batch_size, value_dim)
|
||||||
|
name = f'cached_val1_{i}'
|
||||||
|
logging.info(f"{name}.shape: {tensors[2].shape}")
|
||||||
|
inputs[f"{name}"] = {1: "N"}
|
||||||
|
outputs[f"new_{name}"] = {1: "N"}
|
||||||
|
input_names.append(f"{name}")
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (downsample_left, batch_size, value_dim)
|
||||||
|
name = f'cached_val2_{i}'
|
||||||
|
logging.info(f"{name}.shape: {tensors[3].shape}")
|
||||||
|
inputs[f"{name}"] = {1: "N"}
|
||||||
|
outputs[f"new_{name}"] = {1: "N"}
|
||||||
|
input_names.append(f"{name}")
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (batch_size, embed_dim, conv_left_pad)
|
||||||
|
name = f'cached_conv1_{i}'
|
||||||
|
logging.info(f"{name}.shape: {tensors[4].shape}")
|
||||||
|
inputs[f"{name}"] = {0: "N"}
|
||||||
|
outputs[f"new_{name}"] = {0: "N"}
|
||||||
|
input_names.append(f"{name}")
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (batch_size, embed_dim, conv_left_pad)
|
||||||
|
name = f'cached_conv2_{i}'
|
||||||
|
logging.info(f"{name}.shape: {tensors[5].shape}")
|
||||||
|
inputs[f"{name}"] = {0: "N"}
|
||||||
|
outputs[f"new_{name}"] = {0: "N"}
|
||||||
|
input_names.append(f"{name}")
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
|
||||||
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
|
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
|
||||||
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
|
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
|
||||||
@ -376,30 +420,9 @@ def export_encoder_model_onnx(
|
|||||||
}
|
}
|
||||||
logging.info(f"meta_data: {meta_data}")
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
# (num_encoder_layers, left_context_len, 1, key_dim)
|
for i in range(len(init_state[:-2]) // 6):
|
||||||
cached_key = init_state[0:-2:6]
|
build_inputs_outputs(init_state[i*6:(i+1)*6], i)
|
||||||
|
|
||||||
# (num_encoder_layers, 1, 1, left_context_len, nonlin_attn_head_dim)
|
|
||||||
cached_nonlin_attn = init_state[1:-2:6]
|
|
||||||
|
|
||||||
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
|
||||||
cached_val1 = init_state[2:-2:6]
|
|
||||||
|
|
||||||
# (num_encoder_layers, left_context_len, 1, attention_dim//2)
|
|
||||||
cached_val2 = init_state[3:-2:6]
|
|
||||||
|
|
||||||
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
|
||||||
cached_conv1 = init_state[4:-2:6]
|
|
||||||
|
|
||||||
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
|
||||||
cached_conv2 = init_state[5:-2:6]
|
|
||||||
|
|
||||||
build_inputs_outputs(cached_key, "cached_key", 1)
|
|
||||||
build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 1)
|
|
||||||
build_inputs_outputs(cached_val1, "cached_val1", 1)
|
|
||||||
build_inputs_outputs(cached_val2, "cached_val2", 1)
|
|
||||||
build_inputs_outputs(cached_conv1, "cached_conv1", 0)
|
|
||||||
build_inputs_outputs(cached_conv2, "cached_conv2", 0)
|
|
||||||
|
|
||||||
# (batch_size, channels, left_pad, freq)
|
# (batch_size, channels, left_pad, freq)
|
||||||
embed_states = init_state[-2]
|
embed_states = init_state[-2]
|
||||||
@ -409,7 +432,7 @@ def export_encoder_model_onnx(
|
|||||||
outputs[f"new_{name}"] = {0: "N"}
|
outputs[f"new_{name}"] = {0: "N"}
|
||||||
input_names.append(f"{name}")
|
input_names.append(f"{name}")
|
||||||
output_names.append(f"new_{name}")
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
# (batch_size,)
|
# (batch_size,)
|
||||||
processed_lens = init_state[-1]
|
processed_lens = init_state[-1]
|
||||||
name = f'processed_lens'
|
name = f'processed_lens'
|
||||||
@ -424,13 +447,11 @@ def export_encoder_model_onnx(
|
|||||||
logging.info(input_names)
|
logging.info(input_names)
|
||||||
logging.info(output_names)
|
logging.info(output_names)
|
||||||
|
|
||||||
encoder_model = torch.jit.trace(encoder_model, (x, init_state))
|
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
encoder_model,
|
encoder_model,
|
||||||
(x, init_state),
|
(x, init_state),
|
||||||
encoder_filename,
|
encoder_filename,
|
||||||
verbose=True,
|
verbose=False,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
@ -653,7 +674,7 @@ def main():
|
|||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
encoder = OnnxEncoder(
|
encoder = OnnxEncoder(
|
||||||
encoder=model.encoder,
|
encoder=model.encoder,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user