Fix ONNX export of the latest streaming zipformer model.

This commit is contained in:
Fangjun Kuang 2023-06-27 14:31:06 +08:00
parent 219bba1310
commit d876131205

View File

@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
) )
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
src_key_padding_mask = make_pad_mask(x_lens) src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
# processed_mask is used to mask out initial states # processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand( processed_mask = torch.arange(left_context_len, device=x.device).expand(
@ -272,6 +272,7 @@ class OnnxEncoder(nn.Module):
states = self.encoder.get_init_states(batch_size, device) states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device) embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states) states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device) processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)