diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index 5735ee692..c191b5bcc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -856,6 +856,10 @@ def main(): # Otherwise, one of its arguments is a ragged tensor and is not # torch scriptabe. model.__class__.forward = torch.jit.ignore(model.__class__.forward) + model.encoder.__class__.non_streaming_forward = model.encoder.__class__.forward + model.encoder.__class__.non_streaming_forward = torch.jit.export( + model.encoder.__class__.non_streaming_forward + ) model.encoder.__class__.forward = model.encoder.__class__.streaming_forward logging.info("Using torch.jit.script") model = torch.jit.script(model) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py index 4fd5e1820..c8301b2da 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -252,7 +252,7 @@ def main(): feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( + encoder_out, encoder_out_lens = model.encoder.non_streaming_forward( x=features, x_lens=feature_lengths, )