diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index 1e9ba5378..ebda2252f 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -190,6 +190,7 @@ class OnnxEncoder(nn.Module):
         self.encoder_proj = encoder_proj
         self.chunk_size = encoder.chunk_size[0]
         self.left_context_len = encoder.left_context_frames[0]
+        self.pad_length = 7 + 2 * 3

     def forward(
         self,
@@ -197,7 +198,7 @@ class OnnxEncoder(nn.Module):
         states: List[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         N = x.size(0)
-        T = x.size(1)
+        T = self.chunk_size * 2 + self.pad_length
         x_lens = torch.tensor([T] * N, device=x.device)
         left_context_len = self.left_context_len

@@ -333,8 +334,7 @@ def export_encoder_model_onnx(
     decode_chunk_len = encoder_model.chunk_size * 2
     # The encoder_embed subsample features (T - 7) // 2
     # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
-    pad_length = 7 + 2 * 3
-    T = decode_chunk_len + pad_length
+    T = decode_chunk_len + encoder_model.pad_length

     x = torch.rand(1, T, 80, dtype=torch.float32)
     init_state = encoder_model.get_init_states()
@@ -394,12 +394,30 @@ def export_encoder_model_onnx(
     # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1)
     cached_conv2 = init_state[5:-2:6]

-    build_inputs_outputs(cached_key, "cached_key", 2)
-    build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 2)
-    build_inputs_outputs(cached_val1, "cached_val1", 2)
-    build_inputs_outputs(cached_val2, "cached_val2", 2)
-    build_inputs_outputs(cached_conv1, "cached_conv1", 1)
-    build_inputs_outputs(cached_conv2, "cached_conv2", 1)
+    build_inputs_outputs(cached_key, "cached_key", 1)
+    build_inputs_outputs(cached_nonlin_attn, "cached_nonlin_attn", 1)
+    build_inputs_outputs(cached_val1, "cached_val1", 1)
+    build_inputs_outputs(cached_val2, "cached_val2", 1)
+    build_inputs_outputs(cached_conv1, "cached_conv1", 0)
+    build_inputs_outputs(cached_conv2, "cached_conv2", 0)
+
+    # (batch_size, channels, left_pad, freq)
+    embed_states = init_state[-2]
+    name = "embed_states"
+    logging.info(f"{name}.shape: {embed_states.shape}")
+    inputs[name] = {0: "N"}
+    outputs[f"new_{name}"] = {0: "N"}
+    input_names.append(name)
+    output_names.append(f"new_{name}")
+
+    # (batch_size,)
+    processed_lens = init_state[-1]
+    name = "processed_lens"
+    logging.info(f"{name}.shape: {processed_lens.shape}")
+    inputs[name] = {0: "N"}
+    outputs[f"new_{name}"] = {0: "N"}
+    input_names.append(name)
+    output_names.append(f"new_{name}")

     logging.info(inputs)
     logging.info(outputs)
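The first change moves `pad_length` into `OnnxEncoder.__init__` and fixes `T` at export time instead of reading it from the input: `encoder_embed` consumes 7 frames when subsampling by `(T - 7) // 2`, and the ConvNeXt module needs `(7 - 1) // 2 = 3` subsampled frames of right padding, i.e. `2 * 3` input frames, hence `pad_length = 7 + 2 * 3 = 13`. A minimal sketch of that arithmetic, using an assumed `chunk_size = 16` (the real value comes from `encoder.chunk_size[0]`):

```python
# Sketch only, not part of the patch; chunk_size = 16 is an assumed example value.
chunk_size = 16
pad_length = 7 + 2 * 3             # 7 frames consumed by encoder_embed + 2 * 3 for ConvNeXt

decode_chunk_len = chunk_size * 2  # input frames per streaming chunk (before subsampling)
T = decode_chunk_len + pad_length  # fixed input length baked into the exported graph

# After encoder_embed subsamples, (T - 7) // 2 frames remain:
# chunk_size output frames plus 3 frames of ConvNeXt right padding.
assert (T - 7) // 2 == chunk_size + 3
print(T)  # 45
```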
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 4d0d5fa98..cf810a298 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -1678,15 +1678,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         # [where seq_len2 represents relative position.]
         pos_scores = torch.matmul(p, pos_emb)
+
+        if torch.jit.is_tracing():
+            (num_heads, batch_size, time1, n) = pos_scores.shape
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(seq_len)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+            pos_scores = pos_scores.reshape(-1, n)
+            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
         # the following .as_strided() expression converts the last axis of pos_scores from relative
         # to absolute position.  I don't know whether I might have got the time-offsets backwards or
         # not, but let this code define which way round it is supposed to be.
-        pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
-                                           (pos_scores.stride(0),
-                                            pos_scores.stride(1),
-                                            pos_scores.stride(2)-pos_scores.stride(3),
-                                            pos_scores.stride(3)),
-                                           storage_offset=pos_scores.stride(3) * (seq_len - 1))
+        else:
+            pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
+                                               (pos_scores.stride(0),
+                                                pos_scores.stride(1),
+                                                pos_scores.stride(2)-pos_scores.stride(3),
+                                                pos_scores.stride(3)),
+                                               storage_offset=pos_scores.stride(3) * (seq_len - 1))

         attn_scores = attn_scores + pos_scores
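The `zipformer.py` change adds a `torch.gather()` path for the relative-to-absolute position conversion when `torch.jit.is_tracing()` is true, since the `.as_strided()` expression (overlapping strides plus a storage offset) is not traceable to ONNX. Below is a self-contained sketch, not part of the patch, checking that the two branches compute the same tensor; it assumes example shapes and `time1 == seq_len == k_len`, which holds for plain self-attention where queries and keys have equal length:

```python
import torch

# Assumed example shapes; n = 2 * seq_len - 1 relative positions.
num_heads, batch_size, seq_len = 2, 3, 5
time1, k_len = seq_len, seq_len
n = 2 * seq_len - 1
pos_scores = torch.randn(num_heads, batch_size, time1, n)

# Eager branch: reinterpret strides so row t starts at relative offset (seq_len - 1 - t).
strided = pos_scores.as_strided(
    (num_heads, batch_size, seq_len, k_len),
    (pos_scores.stride(0),
     pos_scores.stride(1),
     pos_scores.stride(2) - pos_scores.stride(3),
     pos_scores.stride(3)),
    storage_offset=pos_scores.stride(3) * (seq_len - 1),
)

# Tracing branch: build the same indices explicitly and gather along the last axis.
rows = torch.arange(start=time1 - 1, end=-1, step=-1)  # offset (time1 - 1 - t) for row t
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols  # entry [r, j] = (time1 - 1 - t) + j, where t = r % time1
gathered = torch.gather(pos_scores.reshape(-1, n), dim=1, index=indexes)
gathered = gathered.reshape(num_heads, batch_size, time1, seq_len)

assert torch.equal(strided, gathered)
```

Both branches select `pos_scores[h, b, t, seq_len - 1 - t + j]`; the gather version just materializes the index tensor so the exported graph contains only ONNX-friendly ops.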