Merge pull request #6 from csukuangfj/small-fixes-online-zipformer2

Small fixes
danfu 2023-06-12 13:19:24 +08:00 committed by GitHub
commit 31ba3418cc
3 changed files with 90 additions and 65 deletions

View File

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
 # Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
 
 """
 This script exports a transducer model from PyTorch to ONNX.
@@ -47,8 +48,12 @@ popd
   --decoder-dim 512 \
   --joiner-dim 512 \
   --causal True \
-  --chunk-size "16,32,64,-1" \
-  --left-context-frames "64,128,256,-1"
+  --chunk-size 16 \
+  --left-context-frames 64
+
+The --chunk-size in training is "16,32,64,-1", so we select one of them
+(excluding -1) during streaming export. The same applies to `--left-context`,
+whose value is "64,128,256,-1".
 
 It will generate the following 3 files inside $repo/exp:
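
As a quick illustration of the constraint described in the new docstring text (the schedule values are taken from the docstring; the check itself is not part of the script):

    # The streaming export pins one concrete value from each training schedule.
    chunk_size_schedule = [16, 32, 64, -1]      # --chunk-size used in training
    left_context_schedule = [64, 128, 256, -1]  # --left-context-frames in training

    chunk_size, left_context = 16, 64           # values chosen for export
    assert chunk_size in chunk_size_schedule and chunk_size != -1
    assert left_context in left_context_schedule and left_context != -1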
@@ -176,7 +181,9 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
 class OnnxEncoder(nn.Module):
     """A wrapper for Zipformer and the encoder_proj from the joiner"""
 
-    def __init__(self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear):
+    def __init__(
+        self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
+    ):
         """
         Args:
           encoder:
@@ -335,7 +342,7 @@ def export_encoder_model_onnx(
     # The encoder_embed subsample features (T - 7) // 2
     # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
     T = decode_chunk_len + encoder_model.pad_length
 
     x = torch.rand(1, T, 80, dtype=torch.float32)
     init_state = encoder_model.get_init_states()
     num_encoders = len(encoder_model.encoder.encoder_dim)
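
For reference, the padding arithmetic in the comments above can be checked by hand; a minimal sketch, assuming decode_chunk_len = 32 as in the exported metadata:

    # encoder_embed consumes 7 frames ((T - 7) // 2 subsampling), and the
    # ConvNeXt module needs 3 subsampled frames of right padding, i.e.
    # 2 * 3 = 6 input frames, so pad_length = 7 + 2 * 3 = 13.
    decode_chunk_len = 32
    pad_length = 7 + 2 * 3
    T = decode_chunk_len + pad_length
    assert T == 45  # matches the "T": "32+7+2*3=45" entry in meta_data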
@@ -347,58 +354,58 @@ def export_encoder_model_onnx(
     outputs = {}
     output_names = ["encoder_out"]
 
     def build_inputs_outputs(tensors, i):
         assert len(tensors) == 6, len(tensors)
 
         # (downsample_left, batch_size, key_dim)
-        name = f'cached_key_{i}'
+        name = f"cached_key_{i}"
         logging.info(f"{name}.shape: {tensors[0].shape}")
-        inputs[f"{name}"] = {1: "N"}
+        inputs[name] = {1: "N"}
         outputs[f"new_{name}"] = {1: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (1, batch_size, downsample_left, nonlin_attn_head_dim)
-        name = f'cached_nonlin_attn_{i}'
+        name = f"cached_nonlin_attn_{i}"
         logging.info(f"{name}.shape: {tensors[1].shape}")
-        inputs[f"{name}"] = {1: "N"}
+        inputs[name] = {1: "N"}
         outputs[f"new_{name}"] = {1: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (downsample_left, batch_size, value_dim)
-        name = f'cached_val1_{i}'
+        name = f"cached_val1_{i}"
         logging.info(f"{name}.shape: {tensors[2].shape}")
-        inputs[f"{name}"] = {1: "N"}
+        inputs[name] = {1: "N"}
         outputs[f"new_{name}"] = {1: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (downsample_left, batch_size, value_dim)
-        name = f'cached_val2_{i}'
+        name = f"cached_val2_{i}"
         logging.info(f"{name}.shape: {tensors[3].shape}")
-        inputs[f"{name}"] = {1: "N"}
+        inputs[name] = {1: "N"}
         outputs[f"new_{name}"] = {1: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (batch_size, embed_dim, conv_left_pad)
-        name = f'cached_conv1_{i}'
+        name = f"cached_conv1_{i}"
         logging.info(f"{name}.shape: {tensors[4].shape}")
-        inputs[f"{name}"] = {0: "N"}
+        inputs[name] = {0: "N"}
         outputs[f"new_{name}"] = {0: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
         # (batch_size, embed_dim, conv_left_pad)
-        name = f'cached_conv2_{i}'
+        name = f"cached_conv2_{i}"
         logging.info(f"{name}.shape: {tensors[5].shape}")
-        inputs[f"{name}"] = {0: "N"}
+        inputs[name] = {0: "N"}
         outputs[f"new_{name}"] = {0: "N"}
-        input_names.append(f"{name}")
+        input_names.append(name)
         output_names.append(f"new_{name}")
 
     num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
     encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
     cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
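
The inputs/outputs dictionaries built above use the layout that torch.onnx.export expects for its dynamic_axes argument, leaving only the batch dimension N dynamic while the chunk length and cache sizes stay fixed. A sketch of how they are plausibly consumed at the export call site (which lies outside this hunk, so the exact arguments here are assumptions):

    # Hypothetical export call; input_names/output_names/inputs/outputs are
    # the lists and dicts populated by build_inputs_outputs above.
    torch.onnx.export(
        encoder_model,
        (x, init_state),                 # example inputs traced at export time
        "encoder.onnx",
        input_names=["x"] + input_names,
        output_names=output_names,       # ["encoder_out", "new_cached_key_0", ...]
        dynamic_axes={
            "x": {0: "N"},               # only the batch axis is dynamic
            "encoder_out": {0: "N"},
            **inputs,                    # {"cached_key_0": {1: "N"}, ...}
            **outputs,                   # {"new_cached_key_0": {1: "N"}, ...}
        },
        opset_version=13,
    )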
@@ -411,10 +418,10 @@ def export_encoder_model_onnx(
     num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
 
     meta_data = {
-        "model_type": "zipformer",
+        "model_type": "zipformer2",
         "version": "1",
         "model_author": "k2-fsa",
-        "comment": "zipformer",
+        "comment": "streaming zipformer2",
         "decode_chunk_len": str(decode_chunk_len),  # 32
         "T": str(T),  # 32+7+2*3=45
         "num_encoder_layers": num_encoder_layers,
@@ -428,25 +435,24 @@ def export_encoder_model_onnx(
     logging.info(f"meta_data: {meta_data}")
 
     for i in range(len(init_state[:-2]) // 6):
-        build_inputs_outputs(init_state[i*6:(i+1)*6], i)
+        build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
 
     # (batch_size, channels, left_pad, freq)
     embed_states = init_state[-2]
-    name = f'embed_states'
+    name = "embed_states"
     logging.info(f"{name}.shape: {embed_states.shape}")
-    inputs[f"{name}"] = {0: "N"}
+    inputs[name] = {0: "N"}
     outputs[f"new_{name}"] = {0: "N"}
-    input_names.append(f"{name}")
+    input_names.append(name)
     output_names.append(f"new_{name}")
 
     # (batch_size,)
     processed_lens = init_state[-1]
-    name = f'processed_lens'
+    name = "processed_lens"
     logging.info(f"{name}.shape: {processed_lens.shape}")
-    inputs[f"{name}"] = {0: "N"}
+    inputs[name] = {0: "N"}
     outputs[f"new_{name}"] = {0: "N"}
-    input_names.append(f"{name}")
+    input_names.append(name)
     output_names.append(f"new_{name}")
 
     logging.info(inputs)

View File

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
 # Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
 
 """
 This script exports a transducer model from PyTorch to ONNX.
@@ -178,7 +179,9 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
 class OnnxEncoder(nn.Module):
     """A wrapper for Zipformer and the encoder_proj from the joiner"""
 
-    def __init__(self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear):
+    def __init__(
+        self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
+    ):
         """
         Args:
           encoder:
@@ -315,10 +318,10 @@ def export_encoder_model_onnx(
     )
 
     meta_data = {
-        "model_type": "zipformer",
+        "model_type": "zipformer2",
         "version": "1",
         "model_author": "k2-fsa",
-        "comment": "zipformer",
+        "comment": "non-streaming zipformer2",
     }
     logging.info(f"meta_data: {meta_data}")

View File

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
 
 """
 This script loads ONNX models exported by ./export-onnx-streaming.py
@@ -47,8 +48,8 @@ popd
   --decoder-dim 512 \
   --joiner-dim 512 \
   --causal True \
-  --chunk-size "16,32,64,-1" \
-  --left-context-frames "64,128,256,-1"
+  --chunk-size 16 \
+  --left-context-frames 64
 
 It will generate the following 3 files inside $repo/exp:
@@ -154,7 +155,7 @@ class OnnxModel:
         logging.info(f"encoder_meta={encoder_meta}")
 
         model_type = encoder_meta["model_type"]
-        assert model_type == "zipformer", model_type
+        assert model_type == "zipformer2", model_type
 
         decode_chunk_len = int(encoder_meta["decode_chunk_len"])
         T = int(encoder_meta["T"])
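
The encoder_meta dictionary is the custom metadata written by add_meta_data at export time. A minimal sketch of how it is read back with onnxruntime (the file name and session options here are assumptions):

    import onnxruntime

    session = onnxruntime.InferenceSession(
        "encoder.onnx", providers=["CPUExecutionProvider"]
    )
    encoder_meta = session.get_modelmeta().custom_metadata_map
    assert encoder_meta["model_type"] == "zipformer2", encoder_meta["model_type"]
    decode_chunk_len = int(encoder_meta["decode_chunk_len"])  # e.g. 32
    T = int(encoder_meta["T"])                                # e.g. 45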
@@ -200,16 +201,31 @@ class OnnxModel:
             conv_left_pad = cnn_module_kernels[i] // 2
             for layer in range(num_layers):
-                cached_key = torch.zeros(left_context_len[i], batch_size, key_dim)
-                cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len[i], nonlin_attn_head_dim)
-                cached_val1 = torch.zeros(left_context_len[i], batch_size, value_dim)
-                cached_val2 = torch.zeros(left_context_len[i], batch_size, value_dim)
-                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad)
-                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad)
-                self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
+                cached_key = torch.zeros(
+                    left_context_len[i], batch_size, key_dim
+                ).numpy()
+                cached_nonlin_attn = torch.zeros(
+                    1, batch_size, left_context_len[i], nonlin_attn_head_dim
+                ).numpy()
+                cached_val1 = torch.zeros(
+                    left_context_len[i], batch_size, value_dim
+                ).numpy()
+                cached_val2 = torch.zeros(
+                    left_context_len[i], batch_size, value_dim
+                ).numpy()
+                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+                self.states += [
+                    cached_key,
+                    cached_nonlin_attn,
+                    cached_val1,
+                    cached_val2,
+                    cached_conv1,
+                    cached_conv2,
+                ]
 
-        embed_states = torch.zeros(batch_size, 128, 3, 19)
+        embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
         self.states.append(embed_states)
-        processed_lens = torch.zeros(batch_size, dtype=torch.int64)
+        processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
         self.states.append(processed_lens)
 
         self.num_encoders = num_encoders
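
After this change every entry in self.states is a numpy.ndarray from the start, so the encoder input can be fed to onnxruntime directly instead of being converted (and type-checked) on every chunk. A hypothetical helper capturing the implied layout, for illustration only:

    import numpy as np

    def check_states(states, num_layers_total: int):
        # 6 cached tensors per encoder layer, plus embed_states and
        # processed_lens appended at the end.
        assert len(states) == num_layers_total * 6 + 2, len(states)
        # No torch.Tensor should remain after initialization.
        assert all(isinstance(s, np.ndarray) for s in states)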
@@ -252,48 +268,48 @@ class OnnxModel:
             assert len(tensors) == 6, len(tensors)
 
             # (downsample_left, batch_size, key_dim)
-            name = f'cached_key_{i}'
-            encoder_input[f"{name}"] = tensors[0].numpy() if isinstance(tensors[0], torch.Tensor) else tensors[0]
+            name = f"cached_key_{i}"
+            encoder_input[name] = tensors[0]
             encoder_output.append(f"new_{name}")
 
             # (1, batch_size, downsample_left, nonlin_attn_head_dim)
-            name = f'cached_nonlin_attn_{i}'
-            encoder_input[f"{name}"] = tensors[1].numpy() if isinstance(tensors[1], torch.Tensor) else tensors[1]
+            name = f"cached_nonlin_attn_{i}"
+            encoder_input[name] = tensors[1]
             encoder_output.append(f"new_{name}")
 
             # (downsample_left, batch_size, value_dim)
-            name = f'cached_val1_{i}'
-            encoder_input[f"{name}"] = tensors[2].numpy() if isinstance(tensors[2], torch.Tensor) else tensors[2]
+            name = f"cached_val1_{i}"
+            encoder_input[name] = tensors[2]
             encoder_output.append(f"new_{name}")
 
             # (downsample_left, batch_size, value_dim)
-            name = f'cached_val2_{i}'
-            encoder_input[f"{name}"] = tensors[3].numpy() if isinstance(tensors[3], torch.Tensor) else tensors[3]
+            name = f"cached_val2_{i}"
+            encoder_input[name] = tensors[3]
             encoder_output.append(f"new_{name}")
 
             # (batch_size, embed_dim, conv_left_pad)
-            name = f'cached_conv1_{i}'
-            encoder_input[f"{name}"] = tensors[4].numpy() if isinstance(tensors[4], torch.Tensor) else tensors[4]
+            name = f"cached_conv1_{i}"
+            encoder_input[name] = tensors[4]
             encoder_output.append(f"new_{name}")
 
             # (batch_size, embed_dim, conv_left_pad)
-            name = f'cached_conv2_{i}'
-            encoder_input[f"{name}"] = tensors[5].numpy() if isinstance(tensors[5], torch.Tensor) else tensors[5]
+            name = f"cached_conv2_{i}"
+            encoder_input[name] = tensors[5]
             encoder_output.append(f"new_{name}")
 
         for i in range(len(self.states[:-2]) // 6):
-            build_inputs_outputs(self.states[i*6:(i+1)*6], i)
+            build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
 
         # (batch_size, channels, left_pad, freq)
-        name = f'embed_states'
+        name = "embed_states"
         embed_states = self.states[-2]
-        encoder_input[f"{name}"] = embed_states.numpy() if isinstance(embed_states, torch.Tensor) else embed_states
+        encoder_input[name] = embed_states
         encoder_output.append(f"new_{name}")
 
         # (batch_size,)
-        name = f'processed_lens'
+        name = "processed_lens"
         processed_lens = self.states[-1]
-        encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
+        encoder_input[name] = processed_lens
         encoder_output.append(f"new_{name}")
 
         return encoder_input, encoder_output
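
Taken together, the returned encoder_input/encoder_output pair is what a streaming decode loop feeds to onnxruntime chunk by chunk, carrying the new_* outputs back as the next call's cached states. A hedged sketch of such a loop (the function name and the input ordering are assumptions based on this diff, not verbatim code):

    import numpy as np
    import onnxruntime as ort

    def stream_encoder(encoder: ort.InferenceSession, states, chunks):
        """Feed fbank chunks of shape (1, T, 80) and carry the cached
        states from one call to the next."""
        input_names = [i.name for i in encoder.get_inputs()]
        output_names = [o.name for o in encoder.get_outputs()]
        for x in chunks:
            # Assumes the first input is the features and the remaining
            # inputs line up with the states list built above.
            feed = dict(zip(input_names, [x.astype(np.float32)] + states))
            encoder_out, *states = encoder.run(output_names, feed)
            yield encoder_out  # goes on to the decoder/joiner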