mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
fix triton onnx export (#1730)
This commit is contained in:
parent
3fc06cc2b9
commit
3b434fe83c
@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
|
|||||||
- encoder_out_lens, A 1-D tensor of shape (N,)
|
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||||
"""
|
"""
|
||||||
x, x_lens = self.encoder_embed(x, x_lens)
|
x, x_lens = self.encoder_embed(x, x_lens)
|
||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens, x.shape[1])
|
||||||
x = x.permute(1, 0, 2)
|
x = x.permute(1, 0, 2)
|
||||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||||
encoder_out = encoder_out.permute(1, 0, 2)
|
encoder_out = encoder_out.permute(1, 0, 2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user