From 25effa8e9df77df74188d7d5a2cc5f6b373cff2a Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 12 Jun 2023 12:36:45 +0800
Subject: [PATCH] small fixes

---
 .../ASR/zipformer/export-onnx-streaming.py   | 72 +++++++++---------
 egs/librispeech/ASR/zipformer/export-onnx.py |  9 ++-
 .../zipformer/onnx_pretrained-streaming.py   | 74 +++++++++++--------
 3 files changed, 90 insertions(+), 65 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index cfef9e6e5..356935657 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
 # Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
 
 """
 This script exports a transducer model from PyTorch to ONNX.
@@ -47,8 +48,12 @@ popd
   --decoder-dim 512 \
   --joiner-dim 512 \
   --causal True \
-  --chunk-size "16,32,64,-1" \
-  --left-context-frames "64,128,256,-1"
+  --chunk-size 16 \
+  --left-context-frames 64
+
+The --chunk-size in training is "16,32,64,-1", so we select one of them
+(excluding -1) during streaming export. The same applies to
+`--left-context-frames`, whose value is "64,128,256,-1".
 
 It will generate the following 3 files inside $repo/exp:
@@ -176,7 +181,9 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
 class OnnxEncoder(nn.Module):
     """A wrapper for Zipformer and the encoder_proj from the joiner"""
 
-    def __init__(self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear):
+    def __init__(
+        self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
+    ):
         """
         Args:
           encoder:
@@ -335,7 +342,7 @@ def export_encoder_model_onnx(
     # The encoder_embed subsample features (T - 7) // 2
     # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
     T = decode_chunk_len + encoder_model.pad_length
-    
+
     x = torch.rand(1, T, 80, dtype=torch.float32)
     init_state = encoder_model.get_init_states()
     num_encoders = len(encoder_model.encoder.encoder_dim)
@@ -347,58 +354,58 @@ def export_encoder_model_onnx(
     outputs = {}
     output_names = ["encoder_out"]
+
     def build_inputs_outputs(tensors, i):
         assert len(tensors) == 6, len(tensors)
 
         # (downsample_left, batch_size, key_dim)
-        name = f'cached_key_{i}'
+        name = f"cached_key_{i}"
         logging.info(f"{name}.shape: {tensors[0].shape}")
-        inputs[f"{name}"] = {1: "N"}
+        inputs[name] = {1: "N"}
         outputs[f"new_{name}"] = {1: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (1, batch_size, downsample_left, nonlin_attn_head_dim)
-        name = f'cached_nonlin_attn_{i}'
+        name = f"cached_nonlin_attn_{i}"
         logging.info(f"{name}.shape: {tensors[1].shape}")
-        inputs[f"{name}"] = {1: "N"}
+        inputs[name] = {1: "N"}
         outputs[f"new_{name}"] = {1: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (downsample_left, batch_size, value_dim)
-        name = f'cached_val1_{i}'
+        name = f"cached_val1_{i}"
         logging.info(f"{name}.shape: {tensors[2].shape}")
-        inputs[f"{name}"] = {1: "N"}
+        inputs[name] = {1: "N"}
        outputs[f"new_{name}"] = {1: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (downsample_left, batch_size, value_dim)
-        name = f'cached_val2_{i}'
+        name = f"cached_val2_{i}"
         logging.info(f"{name}.shape: {tensors[3].shape}")
-        inputs[f"{name}"] = {1: "N"}
+        inputs[name] = {1: "N"}
         outputs[f"new_{name}"] = {1: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (batch_size, embed_dim, conv_left_pad)
-        name = f'cached_conv1_{i}'
+        name = f"cached_conv1_{i}"
         logging.info(f"{name}.shape: {tensors[4].shape}")
-        inputs[f"{name}"] = {0: "N"}
+        inputs[name] = {0: "N"}
         outputs[f"new_{name}"] = {0: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (batch_size, embed_dim, conv_left_pad)
-        name = f'cached_conv2_{i}'
+        name = f"cached_conv2_{i}"
         logging.info(f"{name}.shape: {tensors[5].shape}")
-        inputs[f"{name}"] = {0: "N"}
+        inputs[name] = {0: "N"}
         outputs[f"new_{name}"] = {0: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
-
     num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
     encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
     cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
@@ -411,10 +418,10 @@ def export_encoder_model_onnx(
     num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
 
     meta_data = {
-        "model_type": "zipformer",
+        "model_type": "zipformer2",
         "version": "1",
         "model_author": "k2-fsa",
-        "comment": "zipformer",
+        "comment": "streaming zipformer2",
         "decode_chunk_len": str(decode_chunk_len),  # 32
         "T": str(T),  # 32+7+2*3=45
         "num_encoder_layers": num_encoder_layers,
@@ -428,25 +435,24 @@ def export_encoder_model_onnx(
     logging.info(f"meta_data: {meta_data}")
 
     for i in range(len(init_state[:-2]) // 6):
-        build_inputs_outputs(init_state[i*6:(i+1)*6], i)
-
+        build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
 
     # (batch_size, channels, left_pad, freq)
     embed_states = init_state[-2]
-    name = f'embed_states'
+    name = "embed_states"
     logging.info(f"{name}.shape: {embed_states.shape}")
-    inputs[f"{name}"] = {0: "N"}
+    inputs[name] = {0: "N"}
     outputs[f"new_{name}"] = {0: "N"}
-    input_names.append(f"{name}")
+    input_names.append(name)
     output_names.append(f"new_{name}")
 
     # (batch_size,)
     processed_lens = init_state[-1]
-    name = f'processed_lens'
+    name = "processed_lens"
     logging.info(f"{name}.shape: {processed_lens.shape}")
-    inputs[f"{name}"] = {0: "N"}
+    inputs[name] = {0: "N"}
     outputs[f"new_{name}"] = {0: "N"}
-    input_names.append(f"{name}")
+    input_names.append(name)
     output_names.append(f"new_{name}")
 
     logging.info(inputs)
diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py
index 42c9187d9..490e7c2e9 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
 # Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
 
 """
 This script exports a transducer model from PyTorch to ONNX.
@@ -178,7 +179,9 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
 class OnnxEncoder(nn.Module):
     """A wrapper for Zipformer and the encoder_proj from the joiner"""
 
-    def __init__(self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear):
+    def __init__(
+        self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
+    ):
         """
         Args:
           encoder:
@@ -315,10 +318,10 @@ def export_encoder_model_onnx(
     )
 
     meta_data = {
-        "model_type": "zipformer",
+        "model_type": "zipformer2",
         "version": "1",
         "model_author": "k2-fsa",
-        "comment": "zipformer",
+        "comment": "non-streaming zipformer2",
     }
 
     logging.info(f"meta_data: {meta_data}")
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
index b434a2e76..273f883df 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
 
 """
 This script loads ONNX models exported by ./export-onnx-streaming.py
@@ -47,8 +48,8 @@ popd
   --decoder-dim 512 \
   --joiner-dim 512 \
   --causal True \
-  --chunk-size "16,32,64,-1" \
-  --left-context-frames "64,128,256,-1"
+  --chunk-size 16 \
+  --left-context-frames 64
 
 It will generate the following 3 files inside $repo/exp:
@@ -154,7 +155,7 @@ class OnnxModel:
         logging.info(f"encoder_meta={encoder_meta}")
 
         model_type = encoder_meta["model_type"]
-        assert model_type == "zipformer", model_type
+        assert model_type == "zipformer2", model_type
 
         decode_chunk_len = int(encoder_meta["decode_chunk_len"])
         T = int(encoder_meta["T"])
@@ -200,16 +201,31 @@ class OnnxModel:
             conv_left_pad = cnn_module_kernels[i] // 2
 
             for layer in range(num_layers):
-                cached_key = torch.zeros(left_context_len[i], batch_size, key_dim)
-                cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len[i], nonlin_attn_head_dim)
-                cached_val1 = torch.zeros(left_context_len[i], batch_size, value_dim)
-                cached_val2 = torch.zeros(left_context_len[i], batch_size, value_dim)
-                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad)
-                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad)
-                self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
-        embed_states = torch.zeros(batch_size, 128, 3, 19)
+                cached_key = torch.zeros(
+                    left_context_len[i], batch_size, key_dim
+                ).numpy()
+                cached_nonlin_attn = torch.zeros(
+                    1, batch_size, left_context_len[i], nonlin_attn_head_dim
+                ).numpy()
+                cached_val1 = torch.zeros(
+                    left_context_len[i], batch_size, value_dim
+                ).numpy()
+                cached_val2 = torch.zeros(
+                    left_context_len[i], batch_size, value_dim
+                ).numpy()
+                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+                self.states += [
+                    cached_key,
+                    cached_nonlin_attn,
+                    cached_val1,
+                    cached_val2,
+                    cached_conv1,
+                    cached_conv2,
+                ]
+        embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
         self.states.append(embed_states)
-        processed_lens = torch.zeros(batch_size, dtype=torch.int64)
+        processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
         self.states.append(processed_lens)
 
         self.num_encoders = num_encoders
@@ -252,48 +268,48 @@ class OnnxModel:
             assert len(tensors) == 6, len(tensors)
 
             # (downsample_left, batch_size, key_dim)
-            name = f'cached_key_{i}'
-            encoder_input[f"{name}"] = tensors[0].numpy() if isinstance(tensors[0], torch.Tensor) else tensors[0]
+            name = f"cached_key_{i}"
+            encoder_input[name] = tensors[0]
             encoder_output.append(f"new_{name}")
 
             # (1, batch_size, downsample_left, nonlin_attn_head_dim)
-            name = f'cached_nonlin_attn_{i}'
-            encoder_input[f"{name}"] = tensors[1].numpy() if isinstance(tensors[1], torch.Tensor) else tensors[1]
+            name = f"cached_nonlin_attn_{i}"
+            encoder_input[name] = tensors[1]
             encoder_output.append(f"new_{name}")
 
             # (downsample_left, batch_size, value_dim)
-            name = f'cached_val1_{i}'
-            encoder_input[f"{name}"] = tensors[2].numpy() if isinstance(tensors[2], torch.Tensor) else tensors[2]
+            name = f"cached_val1_{i}"
+            encoder_input[name] = tensors[2]
             encoder_output.append(f"new_{name}")
 
             # (downsample_left, batch_size, value_dim)
-            name = f'cached_val2_{i}'
-            encoder_input[f"{name}"] = tensors[3].numpy() if isinstance(tensors[3], torch.Tensor) else tensors[3]
+            name = f"cached_val2_{i}"
+            encoder_input[name] = tensors[3]
             encoder_output.append(f"new_{name}")
 
             # (batch_size, embed_dim, conv_left_pad)
-            name = f'cached_conv1_{i}'
-            encoder_input[f"{name}"] = tensors[4].numpy() if isinstance(tensors[4], torch.Tensor) else tensors[4]
+            name = f"cached_conv1_{i}"
+            encoder_input[name] = tensors[4]
             encoder_output.append(f"new_{name}")
 
             # (batch_size, embed_dim, conv_left_pad)
-            name = f'cached_conv2_{i}'
-            encoder_input[f"{name}"] = tensors[5].numpy() if isinstance(tensors[5], torch.Tensor) else tensors[5]
+            name = f"cached_conv2_{i}"
+            encoder_input[name] = tensors[5]
             encoder_output.append(f"new_{name}")
 
         for i in range(len(self.states[:-2]) // 6):
-            build_inputs_outputs(self.states[i*6:(i+1)*6], i)
+            build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
 
         # (batch_size, channels, left_pad, freq)
-        name = f'embed_states'
+        name = "embed_states"
         embed_states = self.states[-2]
-        encoder_input[f"{name}"] = embed_states.numpy() if isinstance(embed_states, torch.Tensor) else embed_states
+        encoder_input[name] = embed_states
         encoder_output.append(f"new_{name}")
 
         # (batch_size,)
-        name = f'processed_lens'
+        name = "processed_lens"
         processed_lens = self.states[-1]
-        encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
+        encoder_input[name] = processed_lens
        encoder_output.append(f"new_{name}")
 
         return encoder_input, encoder_output
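
A note on the export side of this patch: the inputs/outputs dictionaries filled by build_inputs_outputs map each cached-state name to the axis that is dynamic (the batch axis "N"), and together with input_names/output_names they are the kind of arguments torch.onnx.export takes as dynamic_axes. The following is a minimal, self-contained sketch of that convention using a toy module; the module, shapes, and file name are illustrative only and are not taken from the patch.

import torch
import torch.nn as nn


class TinyStatefulEncoder(nn.Module):
    """Toy stand-in for the streaming encoder: one feature chunk, one cached state."""

    def forward(self, x: torch.Tensor, cached_key_0: torch.Tensor):
        encoder_out = x.mean(dim=1, keepdim=True)  # fake "encoding" of the chunk
        new_cached_key_0 = cached_key_0 + 1.0      # fake cache update
        return encoder_out, new_cached_key_0


model = TinyStatefulEncoder()
x = torch.rand(1, 45, 80)               # (N, T, feature_dim)
cached_key_0 = torch.zeros(64, 1, 192)  # (downsample_left, N, key_dim)

# Same convention as the script: {axis_index: "N"} marks the batch axis as dynamic.
dynamic_axes = {
    "x": {0: "N"},
    "cached_key_0": {1: "N"},
    "encoder_out": {0: "N"},
    "new_cached_key_0": {1: "N"},
}

torch.onnx.export(
    model,
    (x, cached_key_0),
    "tiny-encoder.onnx",  # illustrative file name
    input_names=["x", "cached_key_0"],
    output_names=["encoder_out", "new_cached_key_0"],
    dynamic_axes=dynamic_axes,
    opset_version=13,
)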
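
The model_type value that onnx_pretrained-streaming.py now asserts ("zipformer2" after this patch) is written into the exported graph by add_meta_data, whose signature appears in the hunk headers above. Below is a hedged sketch of how such key/value metadata is typically attached with the onnx package and read back with onnxruntime; the bodies shown here are assumptions, not the scripts' actual implementations.

from typing import Dict

import onnx
import onnxruntime as ort


def add_meta_data(filename: str, meta_data: Dict[str, str]) -> None:
    """Attach string key/value pairs to an ONNX file (assumed implementation)."""
    model = onnx.load(filename)
    del model.metadata_props[:]  # drop any previously stored metadata
    for key, value in meta_data.items():
        prop = model.metadata_props.add()
        prop.key = key
        prop.value = value
    onnx.save(model, filename)


def read_meta_data(filename: str) -> Dict[str, str]:
    """Read the metadata back the same way the runtime script checks model_type."""
    session = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
    return session.get_modelmeta().custom_metadata_map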
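
For completeness, a sketch of how the numpy state arrays introduced in init_encoder_states flow through one encoder call and come back as the new_* outputs. It reads the input names, shapes, and metadata from the session instead of hard-coding them; the file name, the random features, and the batch_size == 1 assumption are placeholders rather than values from the patch.

import numpy as np
import onnxruntime as ort

# Illustrative path; the real file name depends on the export arguments.
session = ort.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])

meta = session.get_modelmeta().custom_metadata_map
assert meta["model_type"] == "zipformer2", meta["model_type"]
T = int(meta["T"])  # feature frames consumed per chunk, e.g. 32 + 7 + 2 * 3 = 45

# One zero-initialized array per state input; every symbolic dim ("N") is taken
# to be the batch axis and is set to 1 here.
states = {}
for node in session.get_inputs():
    if node.name == "x":
        continue
    shape = [1 if isinstance(d, str) else d for d in node.shape]
    dtype = np.int64 if "int64" in node.type else np.float32
    states[node.name] = np.zeros(shape, dtype=dtype)

output_names = [node.name for node in session.get_outputs()]

for chunk in range(3):  # pretend we have three chunks of audio
    x = np.random.rand(1, T, 80).astype(np.float32)  # stand-in for fbank features
    results = dict(zip(output_names, session.run(output_names, {"x": x, **states})))
    encoder_out = results["encoder_out"]
    # Carry the updated caches to the next chunk: new_cached_key_0 -> cached_key_0, ...
    states = {name: results[f"new_{name}"] for name in states}
    print(chunk, encoder_out.shape)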