mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
Use torch.jit.script for the LSTM transducer decoder model
This commit is contained in:
parent
e1880b7413
commit
7a4d8c9c1d
@ -218,10 +218,9 @@ def export_decoder_model_jit_trace(
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||
# TODO(fangjun): Change the function name since we are actually using
|
||||
# torch.jit.script instead of torch.jit.trace
|
||||
traced_model = torch.jit.script(decoder_model)
|
||||
traced_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user