mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
small fixes

commit 851bb50cb8 (parent 4d4188aa81)
@@ -190,6 +190,7 @@ class OnnxEncoder(nn.Module):
         self.encoder_proj = encoder_proj
         self.chunk_size = encoder.chunk_size[0]
         self.left_context_len = encoder.left_context_frames[0]
+        self.pad_length = 7 + 2 * 3

     def forward(
         self,
@@ -197,7 +198,7 @@ class OnnxEncoder(nn.Module):
         states: List[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         N = x.size(0)
-        T = x.size(1)
+        T = self.chunk_size * 2 + self.pad_length
         x_lens = torch.tensor([T] * N, device=x.device)
         left_context_len = self.left_context_len

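The two hunks above hardcode the chunk geometry at export time: instead of reading T from the input tensor (which tracing would freeze into an opaque constant anyway), the encoder wrapper now derives it from the model configuration. A minimal sketch of the arithmetic, assuming chunk_size = 16 purely for illustration (the real value comes from the checkpoint's encoder.chunk_size):

# Sketch only; chunk_size = 16 is an assumed example value, not from the commit.
chunk_size = 16                   # frames per chunk, after 2x subsampling
pad_length = 7 + 2 * 3            # set in OnnxEncoder.__init__ by this commit
T = chunk_size * 2 + pad_length   # input frames the ONNX model expects: 32 + 13 = 45
print(T)  # 45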
@@ -333,8 +334,7 @@ def export_encoder_model_onnx(
     decode_chunk_len = encoder_model.chunk_size * 2
     # The encoder_embed subsample features (T - 7) // 2
     # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
-    pad_length = 7 + 2 * 3
-    T = decode_chunk_len + pad_length
+    T = decode_chunk_len + encoder_model.pad_length

     x = torch.rand(1, T, 80, dtype=torch.float32)
     init_state = encoder_model.get_init_states()
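The two comment lines spell out where pad_length = 7 + 2 * 3 comes from: the encoder_embed front end keeps (T - 7) // 2 frames, and the ConvNeXt module then consumes another 3 frames of right padding after subsampling. A quick check of that arithmetic (a sketch; decode_chunk_len = 32 is an assumed example value):

# Sketch: confirm T = decode_chunk_len + pad_length leaves exactly
# decode_chunk_len // 2 frames after subsampling and ConvNeXt padding.
decode_chunk_len = 32                 # assumed example: chunk_size 16 * 2
pad_length = 7 + 2 * 3                # 7 for (T - 7) // 2, plus 2 * 3 for ConvNeXt
T = decode_chunk_len + pad_length     # 45
after_subsample = (T - 7) // 2        # 19
after_convnext = after_subsample - 3  # 16 == decode_chunk_len // 2
assert after_convnext == decode_chunk_len // 2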
@@ -394,12 +394,30 @@ def export_encoder_model_onnx(
     # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
     cached_conv2 = init_state[5:-2:6]

-    build_inputs_outputs(cached_key, "cached_key", 2)
-    build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 2)
-    build_inputs_outputs(cached_val1, "cached_val1", 2)
-    build_inputs_outputs(cached_val2, "cached_val2", 2)
-    build_inputs_outputs(cached_conv1, "cached_conv1", 1)
-    build_inputs_outputs(cached_conv2, "cached_conv2", 1)
+    build_inputs_outputs(cached_key, "cached_key", 1)
+    build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 1)
+    build_inputs_outputs(cached_val1, "cached_val1", 1)
+    build_inputs_outputs(cached_val2, "cached_val2", 1)
+    build_inputs_outputs(cached_conv1, "cached_conv1", 0)
+    build_inputs_outputs(cached_conv2, "cached_conv2", 0)

+    # (batch_size, channels, left_pad, freq)
+    embed_states = init_state[-2]
+    name = f'embed_states'
+    logging.info(f"{name}.shape: {embed_states.shape}")
+    inputs[f"{name}"] = {0: "N"}
+    outputs[f"new_{name}"] = {0: "N"}
+    input_names.append(f"{name}")
+    output_names.append(f"new_{name}")
+
+    # (batch_size,)
+    processed_lens = init_state[-1]
+    name = f'processed_lens'
+    logging.info(f"{name}.shape: {processed_lens.shape}")
+    inputs[f"{name}"] = {0: "N"}
+    outputs[f"new_{name}"] = {0: "N"}
+    input_names.append(f"{name}")
+    output_names.append(f"new_{name}")
+
     logging.info(inputs)
     logging.info(outputs)
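This hunk moves the dynamic batch axis of every cached attention/convolution state one position to the left, and registers the two remaining state tensors, embed_states and processed_lens, as ONNX inputs and outputs with a dynamic batch ("N") dimension. build_inputs_outputs itself is not part of this diff; the following is a hypothetical sketch of the kind of helper it appears to be, assuming it fills the same inputs/outputs dicts and input_names/output_names lists that are later passed to torch.onnx.export as dynamic_axes, input_names, and output_names:

# Hypothetical sketch; the real build_inputs_outputs lives elsewhere in the
# export script and may differ in detail.
def build_inputs_outputs(tensors, name, batch_axis):
    """Register a group of per-layer cached states as ONNX inputs/outputs,
    marking batch_axis as the dynamic batch ("N") dimension."""
    for i, tensor in enumerate(tensors):
        tensor_name = f"{name}_{i}"
        logging.info(f"{tensor_name}.shape: {tensor.shape}")
        inputs[tensor_name] = {batch_axis: "N"}
        outputs[f"new_{tensor_name}"] = {batch_axis: "N"}
        input_names.append(tensor_name)
        output_names.append(f"new_{tensor_name}")

Under that reading, changing the third argument from 2 to 1 (and from 1 to 0) corrects which axis of each cached tensor is treated as the batch dimension.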
@@ -1678,15 +1678,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         # [where seq_len2 represents relative position.]
         pos_scores = torch.matmul(p, pos_emb)

+        if torch.jit.is_tracing():
+            (num_heads, batch_size, time1, n) = pos_scores.shape
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(seq_len)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+            pos_scores = pos_scores.reshape(-1, n)
+            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
         # the following .as_strided() expression converts the last axis of pos_scores from relative
         # to absolute position. I don't know whether I might have got the time-offsets backwards or
         # not, but let this code define which way round it is supposed to be.
-        pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
-                                           (pos_scores.stride(0),
-                                            pos_scores.stride(1),
-                                            pos_scores.stride(2)-pos_scores.stride(3),
-                                            pos_scores.stride(3)),
-                                           storage_offset=pos_scores.stride(3) * (seq_len - 1))
+        else:
+            pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
+                                               (pos_scores.stride(0),
+                                                pos_scores.stride(1),
+                                                pos_scores.stride(2)-pos_scores.stride(3),
+                                                pos_scores.stride(3)),
+                                               storage_offset=pos_scores.stride(3) * (seq_len - 1))

         attn_scores = attn_scores + pos_scores
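The .as_strided() expression relies on storage strides, which does not survive ONNX tracing, so this hunk routes tracing through an equivalent torch.gather formulation. A self-contained check (a sketch, not from the commit) that the two branches pick the same elements, assuming time1 == seq_len == k_len and a contiguous pos_scores, as in the streaming export:

import torch

# Sketch: compare the eager .as_strided() branch with the tracing-friendly
# torch.gather branch on random data. Sizes are arbitrary example values.
num_heads, batch_size, seq_len = 2, 3, 5
k_len = seq_len
# Last axis holds 2 * seq_len - 1 relative positions.
pos_scores = torch.randn(num_heads, batch_size, seq_len, 2 * seq_len - 1)

# Eager branch: reinterpret strides to turn relative offsets into absolute positions.
strided = pos_scores.as_strided(
    (num_heads, batch_size, seq_len, k_len),
    (pos_scores.stride(0),
     pos_scores.stride(1),
     pos_scores.stride(2) - pos_scores.stride(3),
     pos_scores.stride(3)),
    storage_offset=pos_scores.stride(3) * (seq_len - 1),
)

# Tracing branch: build explicit gather indices that select the same elements.
(h, b, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(b * h).unsqueeze(-1)
indexes = rows + cols
gathered = torch.gather(pos_scores.reshape(-1, n), dim=1, index=indexes)
gathered = gathered.reshape(h, b, time1, seq_len)

assert torch.equal(strided, gathered)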