mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
modify test_model.py, test if the model can be successfully exported by jit.trace
This commit is contained in:
parent
522a45ce75
commit
8f3645e5cb
@ -20,10 +20,18 @@
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless/test_model.py
|
||||
python ./lstm_transducer_stateless/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from export import (
|
||||
export_decoder_model_jit_trace,
|
||||
export_encoder_model_jit_trace,
|
||||
export_joiner_model_jit_trace,
|
||||
)
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
@ -33,13 +41,33 @@ def test_model():
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.unk_id = 2
|
||||
params.encoder_dim = 512
|
||||
params.rnn_hidden_size = 1024
|
||||
params.num_encoder_layers = 12
|
||||
params.aux_layer_period = 0
|
||||
params.exp_dir = Path("exp_test_model")
|
||||
|
||||
model = get_transducer_model(params)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
torch.jit.script(model)
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
if not os.path.exists(params.exp_dir):
|
||||
os.path.mkdir(params.exp_dir)
|
||||
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||
|
||||
print("The model has been successfully exported using jit.trace.")
|
||||
|
||||
|
||||
def main():
|
||||
|
Loading…
x
Reference in New Issue
Block a user