From d84e0867982596daa50e11aa771f568ba92b03d2 Mon Sep 17 00:00:00 2001 From: danqing fu Date: Sun, 11 Jun 2023 23:53:24 +0800 Subject: [PATCH] small fixes --- egs/librispeech/ASR/zipformer/export-onnx-streaming.py | 4 ++-- .../ASR/zipformer/onnx_pretrained-streaming.py | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 65a8d7264..cfef9e6e5 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -246,7 +246,7 @@ class OnnxEncoder(nn.Module): new_processed_lens, ] - return encoder_out, encoder_out_lens, new_states + return encoder_out, new_states def get_init_states( self, @@ -266,7 +266,7 @@ class OnnxEncoder(nn.Module): embed_states = self.encoder_embed.get_init_states(batch_size, device) states.append(embed_states) - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device) states.append(processed_lens) return states diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index f74036745..b434a2e76 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -209,7 +209,7 @@ class OnnxModel: self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] embed_states = torch.zeros(batch_size, 128, 3, 19) self.states.append(embed_states) - processed_lens = torch.zeros(batch_size, dtype=torch.int32) + processed_lens = torch.zeros(batch_size, dtype=torch.int64) self.states.append(processed_lens) self.num_encoders = num_encoders @@ -296,7 +296,6 @@ class OnnxModel: encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens encoder_output.append(f"new_{name}") - logging.info(f"encoder_output_len={len(encoder_output)}") return encoder_input, encoder_output def _update_states(self, states: List[np.ndarray]): @@ -312,11 +311,8 @@ class OnnxModel: T' is usually equal to ((T-7)//2+1)//2 """ encoder_input, encoder_output_names = self._build_encoder_input_output(x) - # logging.info(encoder_input.keys()) - # logging.info(encoder_output_names) + out = self.encoder.run(encoder_output_names, encoder_input) - len = out.pop(1) - out.append(len.astype(np.int32)) self._update_states(out[1:])