solve zipformer streaming gpu inference

This commit is contained in:
Anjos 2023-03-23 14:34:47 +08:00 committed by GitHub
parent d74822d07b
commit 1559f9c0a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -132,7 +132,8 @@ def export_encoder_model_jit_trace(
states = encoder_model.get_init_state(device=x.device)
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
# traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
traced_model = torch.jit.script(encoder_model)
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")