Merge pull request #6 from csukuangfj/small-fixes-online-zipformer2

Small fixes
This commit is contained in:
danfu 2023-06-12 13:19:24 +08:00 committed by GitHub
commit 31ba3418cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 65 deletions

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
@@ -47,8 +48,12 @@ popd
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
--chunk-size 16 \
--left-context-frames 64
The --chunk-size used in training is "16,32,64,-1", so we select one of those
values (excluding -1) for streaming export. The same applies to
`--left-context-frames`, whose training value is "64,128,256,-1".
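A minimal sketch of this selection rule; the helper name is hypothetical and not part of this PR:

```python
# Hypothetical helper illustrating the rule above; not part of this PR.
def select_streaming_value(training_setting: str, selected: int) -> int:
    """E.g. training_setting="16,32,64,-1", selected=16 -> 16."""
    allowed = [int(v) for v in training_setting.split(",")]
    assert selected in allowed and selected != -1, (selected, allowed)
    return selected

chunk_size = select_streaming_value("16,32,64,-1", 16)
left_context_frames = select_streaming_value("64,128,256,-1", 64)
```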
It will generate the following 3 files inside $repo/exp:
@@ -176,7 +181,9 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear):
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
@@ -335,7 +342,7 @@ def export_encoder_model_onnx(
# The encoder_embed subsamples features: T input frames become (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length
x = torch.rand(1, T, 80, dtype=torch.float32)
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
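To make the padding comment concrete, here is the arithmetic with the values recorded in the meta_data below (decode_chunk_len = 32, T = 45):

```python
decode_chunk_len = 32
# 7 frames are consumed by the (T - 7) // 2 subsampling; 2 * 3 extra input
# frames are needed so 3 subsampled frames of ConvNeXt right padding survive.
pad_length = 7 + 2 * 3
T = decode_chunk_len + pad_length
assert T == 45
assert (T - 7) // 2 == 19  # frames entering the encoder per chunk
```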
@@ -347,58 +354,58 @@ def export_encoder_model_onnx(
outputs = {}
output_names = ["encoder_out"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f'cached_key_{i}'
name = f"cached_key_{i}"
logging.info(f"{name}.shape: {tensors[0].shape}")
inputs[f"{name}"] = {1: "N"}
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(f"{name}")
input_names.append(name)
output_names.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f'cached_nonlin_attn_{i}'
name = f"cached_nonlin_attn_{i}"
logging.info(f"{name}.shape: {tensors[1].shape}")
inputs[f"{name}"] = {1: "N"}
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(f"{name}")
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f'cached_val1_{i}'
name = f"cached_val1_{i}"
logging.info(f"{name}.shape: {tensors[2].shape}")
inputs[f"{name}"] = {1: "N"}
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(f"{name}")
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f'cached_val2_{i}'
name = f"cached_val2_{i}"
logging.info(f"{name}.shape: {tensors[3].shape}")
inputs[f"{name}"] = {1: "N"}
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(f"{name}")
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f'cached_conv1_{i}'
name = f"cached_conv1_{i}"
logging.info(f"{name}.shape: {tensors[4].shape}")
inputs[f"{name}"] = {0: "N"}
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(f"{name}")
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f'cached_conv2_{i}'
name = f"cached_conv2_{i}"
logging.info(f"{name}.shape: {tensors[5].shape}")
inputs[f"{name}"] = {0: "N"}
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(f"{name}")
input_names.append(name)
output_names.append(f"new_{name}")
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
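The inputs/outputs dicts and name lists filled in by build_inputs_outputs feed the dynamic-axes table of the ONNX export. The export call itself is outside this diff; a hedged sketch, with the filename and opset version as placeholders:

```python
torch.onnx.export(
    encoder_model,
    (x, init_state),             # must match OnnxEncoder.forward, not shown here
    "encoder.onnx",              # placeholder filename
    input_names=["x"] + input_names,
    output_names=output_names,   # starts with "encoder_out" (see above)
    dynamic_axes={"x": {0: "N"}, "encoder_out": {0: "N"}, **inputs, **outputs},
    opset_version=13,            # assumed; the real value is set elsewhere
)
```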
@@ -411,10 +418,10 @@ def export_encoder_model_onnx(
num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
meta_data = {
"model_type": "zipformer",
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "zipformer",
"comment": "streaming zipformer2",
"decode_chunk_len": str(decode_chunk_len), # 32
"T": str(T), # 32+7+2*3=45
"num_encoder_layers": num_encoder_layers,
@@ -428,25 +435,24 @@ def export_encoder_model_onnx(
logging.info(f"meta_data: {meta_data}")
for i in range(len(init_state[:-2]) // 6):
build_inputs_outputs(init_state[i*6:(i+1)*6], i)
build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
embed_states = init_state[-2]
name = f'embed_states'
name = "embed_states"
logging.info(f"{name}.shape: {embed_states.shape}")
inputs[f"{name}"] = {0: "N"}
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(f"{name}")
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size,)
processed_lens = init_state[-1]
name = f'processed_lens'
name = "processed_lens"
logging.info(f"{name}.shape: {processed_lens.shape}")
inputs[f"{name}"] = {0: "N"}
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(f"{name}")
input_names.append(name)
output_names.append(f"new_{name}")
logging.info(inputs)
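The body of add_meta_data lies outside this diff; only its signature appears in the hunk headers. A plausible implementation on top of the standard onnx API (a sketch, not the file's actual code):

```python
from typing import Dict

import onnx


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Write key/value pairs into the model's metadata_props."""
    model = onnx.load(filename)
    for key, value in meta_data.items():
        entry = model.metadata_props.add()  # a StringStringEntryProto
        entry.key = key
        entry.value = value
    onnx.save(model, filename)
```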

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
@@ -178,7 +179,9 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear):
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
@@ -315,10 +318,10 @@ def export_encoder_model_onnx(
)
meta_data = {
"model_type": "zipformer",
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "zipformer",
"comment": "non-streaming zipformer2",
}
logging.info(f"meta_data: {meta_data}")

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script loads ONNX models exported by ./export-onnx-streaming.py
@@ -47,8 +48,8 @@ popd
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
--chunk-size 16 \
--left-context-frames 64
It will generate the following 3 files inside $repo/exp:
@@ -154,7 +155,7 @@ class OnnxModel:
logging.info(f"encoder_meta={encoder_meta}")
model_type = encoder_meta["model_type"]
assert model_type == "zipformer", model_type
assert model_type == "zipformer2", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
T = int(encoder_meta["T"])
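The per-stack settings are serialized as comma-separated strings (see the ",".join calls in the export script), so the loader parses them back into lists. A hedged sketch; key names other than decode_chunk_len and T are assumed from the variable names used nearby:

```python
def to_int_list(s: str):
    return list(map(int, s.split(",")))

# Key names below are assumptions inferred from the export-side variables.
num_encoder_layers = to_int_list(encoder_meta["num_encoder_layers"])
cnn_module_kernels = to_int_list(encoder_meta["cnn_module_kernels"])
left_context_len = to_int_list(encoder_meta["left_context_len"])
```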
@@ -200,16 +201,31 @@ class OnnxModel:
conv_left_pad = cnn_module_kernels[i] // 2
for layer in range(num_layers):
cached_key = torch.zeros(left_context_len[i], batch_size, key_dim)
cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len[i], nonlin_attn_head_dim)
cached_val1 = torch.zeros(left_context_len[i], batch_size, value_dim)
cached_val2 = torch.zeros(left_context_len[i], batch_size, value_dim)
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad)
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad)
self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
embed_states = torch.zeros(batch_size, 128, 3, 19)
cached_key = torch.zeros(
left_context_len[i], batch_size, key_dim
).numpy()
cached_nonlin_attn = torch.zeros(
1, batch_size, left_context_len[i], nonlin_attn_head_dim
).numpy()
cached_val1 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_val2 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
self.states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
self.states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64)
processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
self.states.append(processed_lens)
self.num_encoders = num_encoders
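Converting every cached tensor to numpy once in __init__ is what lets the run-time path below drop its per-call isinstance checks: onnxruntime consumes numpy arrays directly. A sanity check one could add at the end of __init__ (variable names assumed from the surrounding code; num_encoder_layers is the per-stack layer-count list):

```python
import numpy as np

# Six cached arrays per layer across all stacks, plus the two globals
# (embed_states and processed_lens) appended at the end.
assert len(self.states) == 6 * sum(num_encoder_layers) + 2
assert all(isinstance(s, np.ndarray) for s in self.states)
```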
@@ -252,48 +268,48 @@ class OnnxModel:
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f'cached_key_{i}'
encoder_input[f"{name}"] = tensors[0].numpy() if isinstance(tensors[0], torch.Tensor) else tensors[0]
name = f"cached_key_{i}"
encoder_input[name] = tensors[0]
encoder_output.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f'cached_nonlin_attn_{i}'
encoder_input[f"{name}"] = tensors[1].numpy() if isinstance(tensors[1], torch.Tensor) else tensors[1]
name = f"cached_nonlin_attn_{i}"
encoder_input[name] = tensors[1]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f'cached_val1_{i}'
encoder_input[f"{name}"] = tensors[2].numpy() if isinstance(tensors[2], torch.Tensor) else tensors[2]
name = f"cached_val1_{i}"
encoder_input[name] = tensors[2]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f'cached_val2_{i}'
encoder_input[f"{name}"] = tensors[3].numpy() if isinstance(tensors[3], torch.Tensor) else tensors[3]
name = f"cached_val2_{i}"
encoder_input[name] = tensors[3]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f'cached_conv1_{i}'
encoder_input[f"{name}"] = tensors[4].numpy() if isinstance(tensors[4], torch.Tensor) else tensors[4]
name = f"cached_conv1_{i}"
encoder_input[name] = tensors[4]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f'cached_conv2_{i}'
encoder_input[f"{name}"] = tensors[5].numpy() if isinstance(tensors[5], torch.Tensor) else tensors[5]
name = f"cached_conv2_{i}"
encoder_input[name] = tensors[5]
encoder_output.append(f"new_{name}")
for i in range(len(self.states[:-2]) // 6):
build_inputs_outputs(self.states[i*6:(i+1)*6], i)
build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
name = f'embed_states'
name = "embed_states"
embed_states = self.states[-2]
encoder_input[f"{name}"] = embed_states.numpy() if isinstance(embed_states, torch.Tensor) else embed_states
encoder_input[name] = embed_states
encoder_output.append(f"new_{name}")
# (batch_size,)
name = f'processed_lens'
name = "processed_lens"
processed_lens = self.states[-1]
encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
encoder_input[name] = processed_lens
encoder_output.append(f"new_{name}")
return encoder_input, encoder_output
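With the input dict and output-name list assembled, a streaming step typically runs the session and carries the returned states into the next chunk. A hedged sketch of the caller; the session argument is assumed to be the encoder's onnxruntime.InferenceSession, and the "x" input name is an assumption:

```python
import numpy as np
import onnxruntime as ort


def run_encoder_chunk(session: ort.InferenceSession, x: np.ndarray,
                      encoder_input: dict, encoder_output: list):
    """One streaming step: feed the current chunk plus cached states,
    return encoder_out and the new states for the next chunk."""
    encoder_input["x"] = x  # the feature input's name is assumed to be "x"
    out = session.run(encoder_output, encoder_input)
    encoder_out, new_states = out[0], list(out[1:])
    return encoder_out, new_states
```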