mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge pull request #6 from csukuangfj/small-fixes-online-zipformer2
Small fixes
This commit is contained in:
commit
31ba3418cc
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
@ -47,8 +48,12 @@ popd
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal True \
|
||||
--chunk-size "16,32,64,-1" \
|
||||
--left-context-frames "64,128,256,-1"
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 64
|
||||
|
||||
The --chunk-size in training is "16,32,64,-1", so we select one of them
|
||||
(excluding -1) during streaming export. The same applies to `--left-context`,
|
||||
whose value is "64,128,256,-1".
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
@ -176,7 +181,9 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear):
|
||||
def __init__(
|
||||
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
@ -335,7 +342,7 @@ def export_encoder_model_onnx(
|
||||
# The encoder_embed subsample features (T - 7) // 2
|
||||
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||
T = decode_chunk_len + encoder_model.pad_length
|
||||
|
||||
|
||||
x = torch.rand(1, T, 80, dtype=torch.float32)
|
||||
init_state = encoder_model.get_init_states()
|
||||
num_encoders = len(encoder_model.encoder.encoder_dim)
|
||||
@ -347,58 +354,58 @@ def export_encoder_model_onnx(
|
||||
|
||||
outputs = {}
|
||||
output_names = ["encoder_out"]
|
||||
|
||||
def build_inputs_outputs(tensors, i):
|
||||
assert len(tensors) == 6, len(tensors)
|
||||
|
||||
# (downsample_left, batch_size, key_dim)
|
||||
name = f'cached_key_{i}'
|
||||
name = f"cached_key_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[0].shape}")
|
||||
inputs[f"{name}"] = {1: "N"}
|
||||
inputs[name] = {1: "N"}
|
||||
outputs[f"new_{name}"] = {1: "N"}
|
||||
input_names.append(f"{name}")
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
|
||||
name = f'cached_nonlin_attn_{i}'
|
||||
name = f"cached_nonlin_attn_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[1].shape}")
|
||||
inputs[f"{name}"] = {1: "N"}
|
||||
inputs[name] = {1: "N"}
|
||||
outputs[f"new_{name}"] = {1: "N"}
|
||||
input_names.append(f"{name}")
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f'cached_val1_{i}'
|
||||
name = f"cached_val1_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[2].shape}")
|
||||
inputs[f"{name}"] = {1: "N"}
|
||||
inputs[name] = {1: "N"}
|
||||
outputs[f"new_{name}"] = {1: "N"}
|
||||
input_names.append(f"{name}")
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f'cached_val2_{i}'
|
||||
name = f"cached_val2_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[3].shape}")
|
||||
inputs[f"{name}"] = {1: "N"}
|
||||
inputs[name] = {1: "N"}
|
||||
outputs[f"new_{name}"] = {1: "N"}
|
||||
input_names.append(f"{name}")
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f'cached_conv1_{i}'
|
||||
name = f"cached_conv1_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[4].shape}")
|
||||
inputs[f"{name}"] = {0: "N"}
|
||||
inputs[name] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(f"{name}")
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f'cached_conv2_{i}'
|
||||
name = f"cached_conv2_{i}"
|
||||
logging.info(f"{name}.shape: {tensors[5].shape}")
|
||||
inputs[f"{name}"] = {0: "N"}
|
||||
inputs[name] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(f"{name}")
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
|
||||
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
|
||||
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
|
||||
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
|
||||
@ -411,10 +418,10 @@ def export_encoder_model_onnx(
|
||||
num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
|
||||
|
||||
meta_data = {
|
||||
"model_type": "zipformer",
|
||||
"model_type": "zipformer2",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "zipformer",
|
||||
"comment": "streaming zipformer2",
|
||||
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||
"T": str(T), # 32+7+2*3=45
|
||||
"num_encoder_layers": num_encoder_layers,
|
||||
@ -428,25 +435,24 @@ def export_encoder_model_onnx(
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
for i in range(len(init_state[:-2]) // 6):
|
||||
build_inputs_outputs(init_state[i*6:(i+1)*6], i)
|
||||
|
||||
build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
|
||||
|
||||
# (batch_size, channels, left_pad, freq)
|
||||
embed_states = init_state[-2]
|
||||
name = f'embed_states'
|
||||
name = "embed_states"
|
||||
logging.info(f"{name}.shape: {embed_states.shape}")
|
||||
inputs[f"{name}"] = {0: "N"}
|
||||
inputs[name] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(f"{name}")
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (batch_size,)
|
||||
processed_lens = init_state[-1]
|
||||
name = f'processed_lens'
|
||||
name = "processed_lens"
|
||||
logging.info(f"{name}.shape: {processed_lens.shape}")
|
||||
inputs[f"{name}"] = {0: "N"}
|
||||
inputs[name] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(f"{name}")
|
||||
input_names.append(name)
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
logging.info(inputs)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script exports a transducer model from PyTorch to ONNX.
|
||||
@ -178,7 +179,9 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
class OnnxEncoder(nn.Module):
|
||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||
|
||||
def __init__(self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear):
|
||||
def __init__(
|
||||
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
@ -315,10 +318,10 @@ def export_encoder_model_onnx(
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "zipformer",
|
||||
"model_type": "zipformer2",
|
||||
"version": "1",
|
||||
"model_author": "k2-fsa",
|
||||
"comment": "zipformer",
|
||||
"comment": "non-streaming zipformer2",
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
|
||||
|
||||
"""
|
||||
This script loads ONNX models exported by ./export-onnx-streaming.py
|
||||
@ -47,8 +48,8 @@ popd
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--causal True \
|
||||
--chunk-size "16,32,64,-1" \
|
||||
--left-context-frames "64,128,256,-1"
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 64
|
||||
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
@ -154,7 +155,7 @@ class OnnxModel:
|
||||
logging.info(f"encoder_meta={encoder_meta}")
|
||||
|
||||
model_type = encoder_meta["model_type"]
|
||||
assert model_type == "zipformer", model_type
|
||||
assert model_type == "zipformer2", model_type
|
||||
|
||||
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
|
||||
T = int(encoder_meta["T"])
|
||||
@ -200,16 +201,31 @@ class OnnxModel:
|
||||
conv_left_pad = cnn_module_kernels[i] // 2
|
||||
|
||||
for layer in range(num_layers):
|
||||
cached_key = torch.zeros(left_context_len[i], batch_size, key_dim)
|
||||
cached_nonlin_attn = torch.zeros(1, batch_size, left_context_len[i], nonlin_attn_head_dim)
|
||||
cached_val1 = torch.zeros(left_context_len[i], batch_size, value_dim)
|
||||
cached_val2 = torch.zeros(left_context_len[i], batch_size, value_dim)
|
||||
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad)
|
||||
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad)
|
||||
self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
|
||||
embed_states = torch.zeros(batch_size, 128, 3, 19)
|
||||
cached_key = torch.zeros(
|
||||
left_context_len[i], batch_size, key_dim
|
||||
).numpy()
|
||||
cached_nonlin_attn = torch.zeros(
|
||||
1, batch_size, left_context_len[i], nonlin_attn_head_dim
|
||||
).numpy()
|
||||
cached_val1 = torch.zeros(
|
||||
left_context_len[i], batch_size, value_dim
|
||||
).numpy()
|
||||
cached_val2 = torch.zeros(
|
||||
left_context_len[i], batch_size, value_dim
|
||||
).numpy()
|
||||
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
|
||||
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
|
||||
self.states += [
|
||||
cached_key,
|
||||
cached_nonlin_attn,
|
||||
cached_val1,
|
||||
cached_val2,
|
||||
cached_conv1,
|
||||
cached_conv2,
|
||||
]
|
||||
embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
|
||||
self.states.append(embed_states)
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64)
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
|
||||
self.states.append(processed_lens)
|
||||
|
||||
self.num_encoders = num_encoders
|
||||
@ -252,48 +268,48 @@ class OnnxModel:
|
||||
assert len(tensors) == 6, len(tensors)
|
||||
|
||||
# (downsample_left, batch_size, key_dim)
|
||||
name = f'cached_key_{i}'
|
||||
encoder_input[f"{name}"] = tensors[0].numpy() if isinstance(tensors[0], torch.Tensor) else tensors[0]
|
||||
name = f"cached_key_{i}"
|
||||
encoder_input[name] = tensors[0]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
|
||||
name = f'cached_nonlin_attn_{i}'
|
||||
encoder_input[f"{name}"] = tensors[1].numpy() if isinstance(tensors[1], torch.Tensor) else tensors[1]
|
||||
name = f"cached_nonlin_attn_{i}"
|
||||
encoder_input[name] = tensors[1]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f'cached_val1_{i}'
|
||||
encoder_input[f"{name}"] = tensors[2].numpy() if isinstance(tensors[2], torch.Tensor) else tensors[2]
|
||||
name = f"cached_val1_{i}"
|
||||
encoder_input[name] = tensors[2]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (downsample_left, batch_size, value_dim)
|
||||
name = f'cached_val2_{i}'
|
||||
encoder_input[f"{name}"] = tensors[3].numpy() if isinstance(tensors[3], torch.Tensor) else tensors[3]
|
||||
name = f"cached_val2_{i}"
|
||||
encoder_input[name] = tensors[3]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f'cached_conv1_{i}'
|
||||
encoder_input[f"{name}"] = tensors[4].numpy() if isinstance(tensors[4], torch.Tensor) else tensors[4]
|
||||
name = f"cached_conv1_{i}"
|
||||
encoder_input[name] = tensors[4]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (batch_size, embed_dim, conv_left_pad)
|
||||
name = f'cached_conv2_{i}'
|
||||
encoder_input[f"{name}"] = tensors[5].numpy() if isinstance(tensors[5], torch.Tensor) else tensors[5]
|
||||
name = f"cached_conv2_{i}"
|
||||
encoder_input[name] = tensors[5]
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
for i in range(len(self.states[:-2]) // 6):
|
||||
build_inputs_outputs(self.states[i*6:(i+1)*6], i)
|
||||
build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
|
||||
|
||||
# (batch_size, channels, left_pad, freq)
|
||||
name = f'embed_states'
|
||||
name = "embed_states"
|
||||
embed_states = self.states[-2]
|
||||
encoder_input[f"{name}"] = embed_states.numpy() if isinstance(embed_states, torch.Tensor) else embed_states
|
||||
encoder_input[name] = embed_states
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
# (batch_size,)
|
||||
name = f'processed_lens'
|
||||
name = "processed_lens"
|
||||
processed_lens = self.states[-1]
|
||||
encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
|
||||
encoder_input[name] = processed_lens
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
return encoder_input, encoder_output
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user