mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
small fixes
This commit is contained in:
parent
d932ed0928
commit
15c7035dad
@ -414,7 +414,8 @@ def export_encoder_model_onnx(
|
|||||||
"model_type": "zipformer",
|
"model_type": "zipformer",
|
||||||
"version": "1",
|
"version": "1",
|
||||||
"model_author": "k2-fsa",
|
"model_author": "k2-fsa",
|
||||||
"f": str(decode_chunk_len), # 32
|
"comment": "zipformer",
|
||||||
|
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||||
"T": str(T), # 32+7+2*3=45
|
"T": str(T), # 32+7+2*3=45
|
||||||
"num_encoder_layers": num_encoder_layers,
|
"num_encoder_layers": num_encoder_layers,
|
||||||
"encoder_dims": encoder_dims,
|
"encoder_dims": encoder_dims,
|
||||||
@ -469,14 +470,6 @@ def export_encoder_model_onnx(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
meta_data = {
|
|
||||||
"model_type": "zipformer",
|
|
||||||
"version": "1",
|
|
||||||
"model_author": "k2-fsa",
|
|
||||||
"comment": "zipformer",
|
|
||||||
}
|
|
||||||
logging.info(f"meta_data: {meta_data}")
|
|
||||||
|
|
||||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -151,6 +151,7 @@ class OnnxModel:
|
|||||||
|
|
||||||
def init_encoder_states(self, batch_size: int = 1):
|
def init_encoder_states(self, batch_size: int = 1):
|
||||||
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
|
||||||
|
logging.info(f"encoder_meta={encoder_meta}")
|
||||||
|
|
||||||
model_type = encoder_meta["model_type"]
|
model_type = encoder_meta["model_type"]
|
||||||
assert model_type == "zipformer", model_type
|
assert model_type == "zipformer", model_type
|
||||||
@ -191,19 +192,25 @@ class OnnxModel:
|
|||||||
|
|
||||||
self.states = []
|
self.states = []
|
||||||
for i in range(num_encoders):
|
for i in range(num_encoders):
|
||||||
|
num_layers = num_encoder_layers[i]
|
||||||
key_dim = query_head_dims[i] * num_heads[i]
|
key_dim = query_head_dims[i] * num_heads[i]
|
||||||
embed_dim = encoder_dims[i]
|
embed_dim = encoder_dims[i]
|
||||||
nonlin_attn_head_dim = 3 * embed_dim // 4
|
nonlin_attn_head_dim = 3 * embed_dim // 4
|
||||||
value_dim = value_head_dims[i] * num_heads[i]
|
value_dim = value_head_dims[i] * num_heads[i]
|
||||||
conv_left_pad = cnn_module_kernels[i] // 2
|
conv_left_pad = cnn_module_kernels[i] // 2
|
||||||
|
|
||||||
cached_key = torch.zeros(left_context_len, batch_size, key_dim)
|
for layer in range(num_layers):
|
||||||
cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len, nonlin_attn_head_dim)
|
cached_key = torch.zeros(left_context_len[i], batch_size, key_dim)
|
||||||
cached_val1 = torch.zeros(left_context_len, batch_size, value_dim)
|
cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len[i], nonlin_attn_head_dim)
|
||||||
cached_val2 = torch.zeros(left_context_len, batch_size, value_dim)
|
cached_val1 = torch.zeros(left_context_len[i], batch_size, value_dim)
|
||||||
|
cached_val2 = torch.zeros(left_context_len[i], batch_size, value_dim)
|
||||||
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad)
|
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad)
|
||||||
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad)
|
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad)
|
||||||
self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
|
self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
|
||||||
|
embed_states = torch.zeros(batch_size, 128, 3, 19)
|
||||||
|
self.states.append(embed_states)
|
||||||
|
processed_lens = torch.zeros(batch_size, dtype=torch.int32)
|
||||||
|
self.states.append(processed_lens)
|
||||||
|
|
||||||
self.num_encoders = num_encoders
|
self.num_encoders = num_encoders
|
||||||
|
|
||||||
@ -277,6 +284,19 @@ class OnnxModel:
|
|||||||
for i in range(len(self.states[:-2]) // 6):
|
for i in range(len(self.states[:-2]) // 6):
|
||||||
build_inputs_outputs(self.states[i*6:(i+1)*6], i)
|
build_inputs_outputs(self.states[i*6:(i+1)*6], i)
|
||||||
|
|
||||||
|
# (batch_size, channels, left_pad, freq)
|
||||||
|
name = f'embed_states'
|
||||||
|
embed_states = self.states[-2]
|
||||||
|
encoder_input[f"{name}"] = embed_states.numpy() if isinstance(embed_states, torch.Tensor) else embed_states
|
||||||
|
encoder_output.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (batch_size,)
|
||||||
|
name = f'processed_lens'
|
||||||
|
processed_lens = self.states[-1]
|
||||||
|
encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
|
||||||
|
encoder_output.append(f"new_{name}")
|
||||||
|
|
||||||
|
logging.info(f"encoder_output_len={len(encoder_output)}")
|
||||||
return encoder_input, encoder_output
|
return encoder_input, encoder_output
|
||||||
|
|
||||||
def _update_states(self, states: List[np.ndarray]):
|
def _update_states(self, states: List[np.ndarray]):
|
||||||
@ -292,7 +312,11 @@ class OnnxModel:
|
|||||||
T' is usually equal to ((T-7)//2+1)//2
|
T' is usually equal to ((T-7)//2+1)//2
|
||||||
"""
|
"""
|
||||||
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||||
|
# logging.info(encoder_input.keys())
|
||||||
|
# logging.info(encoder_output_names)
|
||||||
out = self.encoder.run(encoder_output_names, encoder_input)
|
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||||
|
len = out.pop(1)
|
||||||
|
out.append(len.astype(np.int32))
|
||||||
|
|
||||||
self._update_states(out[1:])
|
self._update_states(out[1:])
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user