small fixes

danqing fu 2023-06-11 21:39:24 +08:00
parent d932ed0928
commit 15c7035dad
2 changed files with 33 additions and 16 deletions

Changed file 1 of 2

@@ -414,7 +414,8 @@ def export_encoder_model_onnx(
         "model_type": "zipformer",
         "version": "1",
         "model_author": "k2-fsa",
-        "f": str(decode_chunk_len),  # 32
+        "comment": "zipformer",
+        "decode_chunk_len": str(decode_chunk_len),  # 32
         "T": str(T),  # 32+7+2*3=45
         "num_encoder_layers": num_encoder_layers,
         "encoder_dims": encoder_dims,
@@ -469,14 +470,6 @@ def export_encoder_model_onnx(
         },
     )
 
-    meta_data = {
-        "model_type": "zipformer",
-        "version": "1",
-        "model_author": "k2-fsa",
-        "comment": "zipformer",
-    }
-    logging.info(f"meta_data: {meta_data}")
-
     add_meta_data(filename=encoder_filename, meta_data=meta_data)
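
The duplicate `meta_data` dict is removed because it would have overwritten the fuller dict built in the first hunk above (losing `decode_chunk_len`, `T`, and the layer/dimension fields) right before the single `add_meta_data` call. `add_meta_data` itself is not shown in this diff; a typical implementation with the `onnx` package is roughly:

    from typing import Dict
    import onnx

    def add_meta_data(filename: str, meta_data: Dict[str, str]):
        """Sketch only: store key/value string pairs in the model's metadata_props."""
        model = onnx.load(filename)
        for key, value in meta_data.items():
            prop = model.metadata_props.add()
            prop.key = key
            prop.value = value
        onnx.save(model, filename)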

Changed file 2 of 2

@@ -151,6 +151,7 @@ class OnnxModel:
     def init_encoder_states(self, batch_size: int = 1):
         encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
+        logging.info(f"encoder_meta={encoder_meta}")
 
         model_type = encoder_meta["model_type"]
         assert model_type == "zipformer", model_type
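
All values in `custom_metadata_map` come back as strings, so scalar fields need `int(...)` and list-valued fields such as `num_encoder_layers` or `encoder_dims` have to be split before use. Assuming they were serialized as comma-separated values (the export side of those fields is not shown here), a small helper could be:

    from typing import List

    def to_int_list(s: str) -> List[int]:
        # e.g. "2,2,3,4,3,2" -> [2, 2, 3, 4, 3, 2]; the comma-separated format is an assumption.
        return list(map(int, s.split(",")))

    # num_encoder_layers = to_int_list(encoder_meta["num_encoder_layers"])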
@@ -191,19 +192,25 @@ class OnnxModel:
         self.states = []
         for i in range(num_encoders):
+            num_layers = num_encoder_layers[i]
             key_dim = query_head_dims[i] * num_heads[i]
             embed_dim = encoder_dims[i]
             nonlin_attn_head_dim = 3 * embed_dim // 4
             value_dim = value_head_dims[i] * num_heads[i]
             conv_left_pad = cnn_module_kernels[i] // 2
 
-            cached_key = torch.zeros(left_context_len, batch_size, key_dim)
-            cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len, nonlin_attn_head_dim)
-            cached_val1 = torch.zeros(left_context_len, batch_size, value_dim)
-            cached_val2 = torch.zeros(left_context_len, batch_size, value_dim)
-            cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad)
-            cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad)
-            self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
+            for layer in range(num_layers):
+                cached_key = torch.zeros(left_context_len[i], batch_size, key_dim)
+                cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len[i], nonlin_attn_head_dim)
+                cached_val1 = torch.zeros(left_context_len[i], batch_size, value_dim)
+                cached_val2 = torch.zeros(left_context_len[i], batch_size, value_dim)
+                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad)
+                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad)
+                self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
+        embed_states = torch.zeros(batch_size, 128, 3, 19)
+        self.states.append(embed_states)
+        processed_lens = torch.zeros(batch_size, dtype=torch.int32)
+        self.states.append(processed_lens)
 
         self.num_encoders = num_encoders
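
With this change each layer of every encoder stack gets its own six cache tensors, sized with the per-stack `left_context_len[i]`, and the flat state list ends with the feature-embedding cache `embed_states` of shape (batch_size, 128, 3, 19) followed by `processed_lens`. A quick sanity check on the resulting layout, using illustrative layer counts:

    # Illustrative numbers; the real ones come from the encoder metadata.
    num_encoder_layers = [2, 2, 3, 4, 3, 2]

    # 6 cached tensors per layer, plus embed_states and processed_lens at the end.
    expected_num_states = 6 * sum(num_encoder_layers) + 2
    print(expected_num_states)  # 98 for this configuration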
@@ -277,6 +284,19 @@ class OnnxModel:
         for i in range(len(self.states[:-2]) // 6):
             build_inputs_outputs(self.states[i*6:(i+1)*6], i)
 
+        # (batch_size, channels, left_pad, freq)
+        name = f'embed_states'
+        embed_states = self.states[-2]
+        encoder_input[f"{name}"] = embed_states.numpy() if isinstance(embed_states, torch.Tensor) else embed_states
+        encoder_output.append(f"new_{name}")
+
+        # (batch_size,)
+        name = f'processed_lens'
+        processed_lens = self.states[-1]
+        encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
+        encoder_output.append(f"new_{name}")
+
+        logging.info(f"encoder_output_len={len(encoder_output)}")
         return encoder_input, encoder_output
 
     def _update_states(self, states: List[np.ndarray]):
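
The two trailing states follow the same convention as the per-layer caches: every feed named `<name>` is paired with a fetch named `new_<name>`, so the updated states come back under predictable output names. Stripped of the zipformer specifics, the pattern is roughly:

    from typing import Dict, List, Tuple
    import numpy as np
    import torch

    def build_state_io(names: List[str], states: list) -> Tuple[Dict[str, np.ndarray], List[str]]:
        """Sketch of the feed/fetch naming pattern used above: every state input
        <name> is paired with an output fetched as new_<name>."""
        inputs, outputs = {}, []
        for name, tensor in zip(names, states):
            inputs[name] = tensor.numpy() if isinstance(tensor, torch.Tensor) else tensor
            outputs.append(f"new_{name}")
        return inputs, outputs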
@@ -292,7 +312,11 @@
           T' is usually equal to ((T-7)//2+1)//2
         """
         encoder_input, encoder_output_names = self._build_encoder_input_output(x)
+        # logging.info(encoder_input.keys())
+        # logging.info(encoder_output_names)
 
         out = self.encoder.run(encoder_output_names, encoder_input)
 
+        len = out.pop(1)
+        out.append(len.astype(np.int32))
         self._update_states(out[1:])
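
`init_encoder_states` created `processed_lens` as the last state and with dtype int32, and ONNX Runtime rejects a feed whose dtype does not match the graph input, so the second run output is moved to the end and cast before `out[1:]` is recycled as the next chunk's states. In isolation the reshuffle looks like this (toy shapes and values; `lens` is used instead of the diff's `len` to avoid shadowing the builtin):

    import numpy as np

    # Toy stand-ins for the encoder outputs: encoder_out first, a lengths-like
    # tensor second, then the remaining new states (shapes are illustrative).
    out = [
        np.zeros((1, 16, 256), dtype=np.float32),   # encoder_out
        np.array([45], dtype=np.int64),             # output recycled as processed_lens
        np.zeros((128, 1, 192), dtype=np.float32),  # one stand-in state
    ]
    lens = out.pop(1)                  # remove it from position 1 ...
    out.append(lens.astype(np.int32))  # ... and re-attach it last, as int32
    # out[0] stays encoder_out; out[1:] now ends with the int32 lengths tensor.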