small fixes

This commit is contained in:
danqing fu 2023-06-08 09:47:00 +08:00
parent 4d4188aa81
commit 851bb50cb8
2 changed files with 44 additions and 15 deletions

View File

@ -190,6 +190,7 @@ class OnnxEncoder(nn.Module):
self.encoder_proj = encoder_proj
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
self.pad_length = 7 + 2 * 3
def forward(
self,
@ -197,7 +198,7 @@ class OnnxEncoder(nn.Module):
states: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
N = x.size(0)
T = x.size(1)
T = self.chunk_size * 2 + self.pad_length
x_lens = torch.tensor([T] * N, device=x.device)
left_context_len = self.left_context_len
@ -333,8 +334,7 @@ def export_encoder_model_onnx(
decode_chunk_len = encoder_model.chunk_size * 2
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
pad_length = 7 + 2 * 3
T = decode_chunk_len + pad_length
T = decode_chunk_len + encoder_model.pad_length
x = torch.rand(1, T, 80, dtype=torch.float32)
init_state = encoder_model.get_init_states()
@ -394,12 +394,30 @@ def export_encoder_model_onnx(
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
cached_conv2 = init_state[5:-2:6]
build_inputs_outputs(cached_key, "cached_key", 2)
build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 2)
build_inputs_outputs(cached_val1, "cached_val1", 2)
build_inputs_outputs(cached_val2, "cached_val2", 2)
build_inputs_outputs(cached_conv1, "cached_conv1", 1)
build_inputs_outputs(cached_conv2, "cached_conv2", 1)
build_inputs_outputs(cached_key, "cached_key", 1)
build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 1)
build_inputs_outputs(cached_val1, "cached_val1", 1)
build_inputs_outputs(cached_val2, "cached_val2", 1)
build_inputs_outputs(cached_conv1, "cached_conv1", 0)
build_inputs_outputs(cached_conv2, "cached_conv2", 0)
# (batch_size, channels, left_pad, freq)
embed_states = init_state[-2]
name = f'embed_states'
logging.info(f"{name}.shape: {embed_states.shape}")
inputs[f"{name}"] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(f"{name}")
output_names.append(f"new_{name}")
# (batch_size,)
processed_lens = init_state[-1]
name = f'processed_lens'
logging.info(f"{name}.shape: {processed_lens.shape}")
inputs[f"{name}"] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(f"{name}")
output_names.append(f"new_{name}")
logging.info(inputs)
logging.info(outputs)

View File

@ -1678,15 +1678,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# [where seq_len2 represents relative position.]
pos_scores = torch.matmul(p, pos_emb)
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
else:
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
attn_scores = attn_scores + pos_scores