mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
small fixes
This commit is contained in:
parent
4d4188aa81
commit
851bb50cb8
@ -190,6 +190,7 @@ class OnnxEncoder(nn.Module):
|
||||
self.encoder_proj = encoder_proj
|
||||
self.chunk_size = encoder.chunk_size[0]
|
||||
self.left_context_len = encoder.left_context_frames[0]
|
||||
self.pad_length = 7 + 2 * 3
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -197,7 +198,7 @@ class OnnxEncoder(nn.Module):
|
||||
states: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
N = x.size(0)
|
||||
T = x.size(1)
|
||||
T = self.chunk_size * 2 + self.pad_length
|
||||
x_lens = torch.tensor([T] * N, device=x.device)
|
||||
left_context_len = self.left_context_len
|
||||
|
||||
@ -333,8 +334,7 @@ def export_encoder_model_onnx(
|
||||
decode_chunk_len = encoder_model.chunk_size * 2
|
||||
# The encoder_embed subsample features (T - 7) // 2
|
||||
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||
pad_length = 7 + 2 * 3
|
||||
T = decode_chunk_len + pad_length
|
||||
T = decode_chunk_len + encoder_model.pad_length
|
||||
|
||||
x = torch.rand(1, T, 80, dtype=torch.float32)
|
||||
init_state = encoder_model.get_init_states()
|
||||
@ -394,12 +394,30 @@ def export_encoder_model_onnx(
|
||||
# (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
|
||||
cached_conv2 = init_state[5:-2:6]
|
||||
|
||||
build_inputs_outputs(cached_key, "cached_key", 2)
|
||||
build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 2)
|
||||
build_inputs_outputs(cached_val1, "cached_val1", 2)
|
||||
build_inputs_outputs(cached_val2, "cached_val2", 2)
|
||||
build_inputs_outputs(cached_conv1, "cached_conv1", 1)
|
||||
build_inputs_outputs(cached_conv2, "cached_conv2", 1)
|
||||
build_inputs_outputs(cached_key, "cached_key", 1)
|
||||
build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 1)
|
||||
build_inputs_outputs(cached_val1, "cached_val1", 1)
|
||||
build_inputs_outputs(cached_val2, "cached_val2", 1)
|
||||
build_inputs_outputs(cached_conv1, "cached_conv1", 0)
|
||||
build_inputs_outputs(cached_conv2, "cached_conv2", 0)
|
||||
|
||||
# (batch_size, channels, left_pad, freq)
|
||||
embed_states = init_state[-2]
|
||||
name = f'embed_states'
|
||||
logging.info(f"{name}.shape: {embed_states.shape}")
|
||||
inputs[f"{name}"] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(f"{name}")
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
# (batch_size,)
|
||||
processed_lens = init_state[-1]
|
||||
name = f'processed_lens'
|
||||
logging.info(f"{name}.shape: {processed_lens.shape}")
|
||||
inputs[f"{name}"] = {0: "N"}
|
||||
outputs[f"new_{name}"] = {0: "N"}
|
||||
input_names.append(f"{name}")
|
||||
output_names.append(f"new_{name}")
|
||||
|
||||
logging.info(inputs)
|
||||
logging.info(outputs)
|
||||
|
||||
@ -1678,15 +1678,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
||||
# [where seq_len2 represents relative position.]
|
||||
pos_scores = torch.matmul(p, pos_emb)
|
||||
|
||||
if torch.jit.is_tracing():
|
||||
(num_heads, batch_size, time1, n) = pos_scores.shape
|
||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||
cols = torch.arange(seq_len)
|
||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
pos_scores = pos_scores.reshape(-1, n)
|
||||
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
||||
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
|
||||
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
||||
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||
# not, but let this code define which way round it is supposed to be.
|
||||
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
|
||||
(pos_scores.stride(0),
|
||||
pos_scores.stride(1),
|
||||
pos_scores.stride(2)-pos_scores.stride(3),
|
||||
pos_scores.stride(3)),
|
||||
storage_offset=pos_scores.stride(3) * (seq_len - 1))
|
||||
else:
|
||||
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
|
||||
(pos_scores.stride(0),
|
||||
pos_scores.stride(1),
|
||||
pos_scores.stride(2)-pos_scores.stride(3),
|
||||
pos_scores.stride(3)),
|
||||
storage_offset=pos_scores.stride(3) * (seq_len - 1))
|
||||
|
||||
attn_scores = attn_scores + pos_scores
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user