mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix ONNX export of the latest streaming zipformer model.
This commit is contained in:
parent
219bba1310
commit
d876131205
@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
|
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
|
||||||
|
|
||||||
# processed_mask is used to mask out initial states
|
# processed_mask is used to mask out initial states
|
||||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||||
@ -272,6 +272,7 @@ class OnnxEncoder(nn.Module):
|
|||||||
states = self.encoder.get_init_states(batch_size, device)
|
states = self.encoder.get_init_states(batch_size, device)
|
||||||
|
|
||||||
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
||||||
|
|
||||||
states.append(embed_states)
|
states.append(embed_states)
|
||||||
|
|
||||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user