diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index a86021b78..65a8d7264 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -414,7 +414,8 @@ def export_encoder_model_onnx( "model_type": "zipformer", "version": "1", "model_author": "k2-fsa", - "f": str(decode_chunk_len), # 32 + "comment": "zipformer", + "decode_chunk_len": str(decode_chunk_len), # 32 "T": str(T), # 32+7+2*3=45 "num_encoder_layers": num_encoder_layers, "encoder_dims": encoder_dims, @@ -469,14 +470,6 @@ def export_encoder_model_onnx( }, ) - meta_data = { - "model_type": "zipformer", - "version": "1", - "model_author": "k2-fsa", - "comment": "zipformer", - } - logging.info(f"meta_data: {meta_data}") - add_meta_data(filename=encoder_filename, meta_data=meta_data) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index ec74001f6..f74036745 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -151,6 +151,7 @@ class OnnxModel: def init_encoder_states(self, batch_size: int = 1): encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + logging.info(f"encoder_meta={encoder_meta}") model_type = encoder_meta["model_type"] assert model_type == "zipformer", model_type @@ -191,19 +192,25 @@ class OnnxModel: self.states = [] for i in range(num_encoders): + num_layers = num_encoder_layers[i] key_dim = query_head_dims[i] * num_heads[i] embed_dim = encoder_dims[i] nonlin_attn_head_dim = 3 * embed_dim // 4 value_dim = value_head_dims[i] * num_heads[i] conv_left_pad = cnn_module_kernels[i] // 2 - cached_key = torch.zeros(left_context_len, batch_size, key_dim) - cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len, nonlin_attn_head_dim) - cached_val1 = torch.zeros(left_context_len, batch_size, value_dim) - cached_val2 = torch.zeros(left_context_len, batch_size, value_dim) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad) - self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] + for layer in range(num_layers): + cached_key = torch.zeros(left_context_len[i], batch_size, key_dim) + cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len[i], nonlin_attn_head_dim) + cached_val1 = torch.zeros(left_context_len[i], batch_size, value_dim) + cached_val2 = torch.zeros(left_context_len[i], batch_size, value_dim) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad) + self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] + embed_states = torch.zeros(batch_size, 128, 3, 19) + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int32) + self.states.append(processed_lens) self.num_encoders = num_encoders @@ -277,6 +284,19 @@ class OnnxModel: for i in range(len(self.states[:-2]) // 6): build_inputs_outputs(self.states[i*6:(i+1)*6], i) + # (batch_size, channels, left_pad, freq) + name = f'embed_states' + embed_states = self.states[-2] + encoder_input[f"{name}"] = embed_states.numpy() if isinstance(embed_states, torch.Tensor) else embed_states + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = f'processed_lens' + processed_lens = self.states[-1] + encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens + encoder_output.append(f"new_{name}") + + logging.info(f"encoder_output_len={len(encoder_output)}") return encoder_input, encoder_output def _update_states(self, states: List[np.ndarray]): @@ -292,7 +312,11 @@ class OnnxModel: T' is usually equal to ((T-7)//2+1)//2 """ encoder_input, encoder_output_names = self._build_encoder_input_output(x) + # logging.info(encoder_input.keys()) + # logging.info(encoder_output_names) out = self.encoder.run(encoder_output_names, encoder_input) + len = out.pop(1) + out.append(len.astype(np.int32)) self._update_states(out[1:])