diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 43d775cff..3dbc355d1 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -269,7 +269,7 @@ class OnnxEncoder(nn.Module): processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) states.append(processed_lens) - return states + return states class OnnxDecoder(nn.Module): @@ -347,13 +347,57 @@ def export_encoder_model_onnx( outputs = {} output_names = ["encoder_out"] - def build_inputs_outputs(tensors, name, N): - for i, s in enumerate(tensors): - logging.info(f"{name}_{i}.shape: {s.shape}") - inputs[f"{name}_{i}"] = {N: "N"} - outputs[f"new_{name}_{i}"] = {N: "N"} - input_names.append(f"{name}_{i}") - output_names.append(f"new_{name}_{i}") + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f'cached_key_{i}' + logging.info(f"{name}.shape: {tensors[0].shape}") + inputs[f"{name}"] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(f"{name}") + output_names.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f'cached_nonlin_attn_{i}' + logging.info(f"{name}.shape: {tensors[1].shape}") + inputs[f"{name}"] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(f"{name}") + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f'cached_val1_{i}' + logging.info(f"{name}.shape: {tensors[2].shape}") + inputs[f"{name}"] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(f"{name}") + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f'cached_val2_{i}' + logging.info(f"{name}.shape: {tensors[3].shape}") + inputs[f"{name}"] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(f"{name}") + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f'cached_conv1_{i}' + logging.info(f"{name}.shape: {tensors[4].shape}") + inputs[f"{name}"] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(f"{name}") + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f'cached_conv2_{i}' + logging.info(f"{name}.shape: {tensors[5].shape}") + inputs[f"{name}"] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(f"{name}") + output_names.append(f"new_{name}") + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim)) @@ -376,30 +420,9 @@ def export_encoder_model_onnx( } logging.info(f"meta_data: {meta_data}") - # (num_encoder_layers, left_context_len, 1, key_dim) - cached_key = init_state[0:-2:6] + for i in range(len(init_state[:-2]) // 6): + build_inputs_outputs(init_state[i*6:(i+1)*6], i) - # (num_encoder_layers, 1, 1, left_context_len, nonlin_attn_head_dim) - cached_nonlin_attn = init_state[1:-2:6] - - # (num_encoder_layers, left_context_len, 1, attention_dim//2) - cached_val1 = init_state[2:-2:6] - - # (num_encoder_layers, left_context_len, 1, attention_dim//2) - cached_val2 = init_state[3:-2:6] - - # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) - cached_conv1 = init_state[4:-2:6] - - # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) - cached_conv2 = init_state[5:-2:6] - - build_inputs_outputs(cached_key, "cached_key", 1) - build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 1) - build_inputs_outputs(cached_val1, "cached_val1", 1) - build_inputs_outputs(cached_val2, "cached_val2", 1) - build_inputs_outputs(cached_conv1, "cached_conv1", 0) - build_inputs_outputs(cached_conv2, "cached_conv2", 0) # (batch_size, channels, left_pad, freq) embed_states = init_state[-2] @@ -409,7 +432,7 @@ def export_encoder_model_onnx( outputs[f"new_{name}"] = {0: "N"} input_names.append(f"{name}") output_names.append(f"new_{name}") - + # (batch_size,) processed_lens = init_state[-1] name = f'processed_lens' @@ -424,13 +447,11 @@ def export_encoder_model_onnx( logging.info(input_names) logging.info(output_names) - encoder_model = torch.jit.trace(encoder_model, (x, init_state)) - torch.onnx.export( encoder_model, (x, init_state), encoder_filename, - verbose=True, + verbose=False, opset_version=opset_version, input_names=input_names, output_names=output_names, @@ -653,7 +674,7 @@ def main(): model.to("cpu") model.eval() - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + convert_scaled_to_non_scaled(model, inplace=True) encoder = OnnxEncoder( encoder=model.encoder,