diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py index 4af742316..2c21c16dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -132,7 +132,8 @@ def export_encoder_model_jit_trace( states = encoder_model.get_init_state(device=x.device) encoder_model.__class__.forward = encoder_model.__class__.streaming_forward - traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + # traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model = torch.jit.script(encoder_model) traced_model.save(encoder_filename) logging.info(f"Saved to {encoder_filename}")