From 4cb239518655661f5903ef6c17be08f57a1a487b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 6 Jun 2023 14:20:12 +0800 Subject: [PATCH] Support longer input for the offline model --- egs/librispeech/ASR/zipformer/export-onnx.py | 15 ++++++++++++++- egs/librispeech/ASR/zipformer/zipformer.py | 6 ------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index fea669fd0..e8777b13c 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -74,7 +74,7 @@ from decoder import Decoder from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from zipformer import Zipformer2 +from zipformer import Zipformer2, CompactRelPositionalEncoding from icefall.checkpoint import ( average_checkpoints, @@ -296,6 +296,19 @@ def export_encoder_model_onnx( x = torch.zeros(1, 100, 80, dtype=torch.float32) x_lens = torch.tensor([100], dtype=torch.int64) + # It assumes that the maximum input, after downsampling, won't have more + # than 10k frames. + # The first downsampling factor is 2, so the maximum input + # should contain less than 20k frames, e.g., less than 200 seconds, + # i.e., 3.3 minutes + # + # Note: If you want to handle a longer input audio, please increase this + # value. The downside is that it will increase the size of the model. 
+ max_len = 10000 + for name, m in encoder_model.named_modules(): + if isinstance(m, CompactRelPositionalEncoding): + m.extend_pe(torch.tensor(0.0).expand(max_len)) + torch.onnx.export( encoder_model, (x, x_lens), diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 85ebdb56e..15022947f 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1305,12 +1305,6 @@ class CompactRelPositionalEncoding(torch.nn.Module): ) -> None: """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() - if torch.jit.is_tracing: - # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., - # It assumes that the maximum input won't have more than - # 10k frames. - # - max_len = 10000 self.embed_dim = embed_dim assert embed_dim % 2 == 0 self.dropout = Dropout2(dropout_rate)