mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fix ONNX export of the latest streaming zipformer model. (#1148)
This commit is contained in:
parent
219bba1310
commit
968ebd236b
@ -86,7 +86,7 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import make_pad_mask, str2bool
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
|
||||
)
|
||||
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
|
||||
|
||||
# processed_mask is used to mask out initial states
|
||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||
@ -272,6 +272,7 @@ class OnnxEncoder(nn.Module):
|
||||
states = self.encoder.get_init_states(batch_size, device)
|
||||
|
||||
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
||||
|
||||
states.append(embed_states)
|
||||
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user