mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix ONNX export of the latest streaming zipformer model. (#1148)
This commit is contained in:
parent
219bba1310
commit
968ebd236b
@ -86,7 +86,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import make_pad_mask, str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
|
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
|
||||||
|
|
||||||
# processed_mask is used to mask out initial states
|
# processed_mask is used to mask out initial states
|
||||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||||
@ -272,6 +272,7 @@ class OnnxEncoder(nn.Module):
|
|||||||
states = self.encoder.get_init_states(batch_size, device)
|
states = self.encoder.get_init_states(batch_size, device)
|
||||||
|
|
||||||
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
||||||
|
|
||||||
states.append(embed_states)
|
states.append(embed_states)
|
||||||
|
|
||||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user