From b8a17944e4a1f7a8b04830281affb0b97f26a100 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 5 Jul 2023 10:23:35 +0800 Subject: [PATCH] Fix zipformer CI test (#1164) --- .../ASR/pruned_transducer_stateless7_streaming/export.py | 4 ++++ .../pruned_transducer_stateless7_streaming/jit_pretrained.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index 5735ee692..c191b5bcc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -856,6 +856,10 @@ def main(): # Otherwise, one of its arguments is a ragged tensor and is not # torch scriptabe. model.__class__.forward = torch.jit.ignore(model.__class__.forward) + model.encoder.__class__.non_streaming_forward = model.encoder.__class__.forward + model.encoder.__class__.non_streaming_forward = torch.jit.export( + model.encoder.__class__.non_streaming_forward + ) model.encoder.__class__.forward = model.encoder.__class__.streaming_forward logging.info("Using torch.jit.script") model = torch.jit.script(model) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py index 4fd5e1820..c8301b2da 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -252,7 +252,7 @@ def main(): feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( + encoder_out, encoder_out_lens = model.encoder.non_streaming_forward( x=features, x_lens=feature_lengths, )