Fix ONNX export of the latest streaming zipformer model. (#1148)

This commit is contained in:
Fangjun Kuang 2023-06-27 14:35:59 +08:00 committed by GitHub
parent 219bba1310
commit 968ebd236b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -86,7 +86,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
from icefall.utils import str2bool
def get_parser():
@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module):
)
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
@ -272,6 +272,7 @@ class OnnxEncoder(nn.Module):
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)