From e32658e620eedeef3b82a7a1dc2f084e62f35861 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 17 Apr 2023 16:13:30 +0800 Subject: [PATCH] Fix torch.jit.script() export for streaming zipformer. (#1005) --- .../ASR/pruned_transducer_stateless7_streaming/export.py | 1 + .../ASR/pruned_transducer_stateless7_streaming/zipformer.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index 1bc54fa26..5735ee692 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -856,6 +856,7 @@ def main(): # Otherwise, one of its arguments is a ragged tensor and is not # torch scriptabe. model.__class__.forward = torch.jit.ignore(model.__class__.forward) + model.encoder.__class__.forward = model.encoder.__class__.streaming_forward logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index 0a6886dec..a5c422959 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -570,7 +570,6 @@ class Zipformer(EncoderInterface): return x, lengths - @torch.jit.export def streaming_forward( self, x: torch.Tensor,