fix triton onnx export (#1730)

This commit is contained in:
Yuekai Zhang 2024-08-23 09:33:46 +08:00 committed by GitHub
parent 3fc06cc2b9
commit 3b434fe83c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
- encoder_out_lens, A 1-D tensor of shape (N,) - encoder_out_lens, A 1-D tensor of shape (N,)
""" """
x, x_lens = self.encoder_embed(x, x_lens) x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens) src_key_padding_mask = make_pad_mask(x_lens, x.shape[1])
x = x.permute(1, 0, 2) x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) encoder_out = encoder_out.permute(1, 0, 2)