Use torch.jit.script for the LSTM transducer decoder model

This commit is contained in:
Fangjun Kuang 2024-01-26 16:07:16 +08:00
parent e1880b7413
commit 7a4d8c9c1d

View File

@ -218,10 +218,9 @@ def export_decoder_model_jit_trace(
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
# TODO(fangjun): Change the function name since we are actually using
# torch.jit.script instead of torch.jit.trace
traced_model = torch.jit.script(decoder_model)
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")