modify test_model.py, test if the model can be successfully exported by jit.trace

This commit is contained in:
yaozengwei 2022-08-10 16:20:25 +08:00
parent 522a45ce75
commit 8f3645e5cb

View File

@ -20,10 +20,18 @@
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless/test_model.py
python ./lstm_transducer_stateless/test_model.py
"""
import torch
import os
from pathlib import Path
from export import (
export_decoder_model_jit_trace,
export_encoder_model_jit_trace,
export_joiner_model_jit_trace,
)
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
@ -33,13 +41,33 @@ def test_model():
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.encoder_dim = 512
params.rnn_hidden_size = 1024
params.num_encoder_layers = 12
params.aux_layer_period = 0
params.exp_dir = Path("exp_test_model")
model = get_transducer_model(params)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
convert_scaled_to_non_scaled(model, inplace=True)
if not os.path.exists(params.exp_dir):
os.path.mkdir(params.exp_dir)
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
print("The model has been successfully exported using jit.trace.")
def main():