#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
"""
This script checks that exported ONNX models produce the same output
with the given torchscript model for the same input.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp
ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
popd
2. Export the model via torch.jit.trace()
./lstm_transducer_stateless2/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp/ \
--jit-trace 1
It will generate the following 3 files inside $repo/exp
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
3. Export the model to ONNX
./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
4. Run this file
./lstm_transducer_stateless2/onnx_check.py \
--jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
--jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
--jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
"""
import argparse
import logging

import torch
from onnx_pretrained import OnnxModel

from icefall import is_module_available
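
# A minimal sanity check (an assumption here, mirroring other icefall export
# scripts): fail early with a clear message if onnxruntime is missing, since
# OnnxModel needs it to create inference sessions.
if not is_module_available("onnxruntime"):
    raise ValueError("Please 'pip install onnxruntime' first.")
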
def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--jit-encoder-filename",
        required=True,
        type=str,
        help="Path to the torchscript encoder model",
    )

    parser.add_argument(
        "--jit-decoder-filename",
        required=True,
        type=str,
        help="Path to the torchscript decoder model",
    )

    parser.add_argument(
        "--jit-joiner-filename",
        required=True,
        type=str,
        help="Path to the torchscript joiner model",
    )

    parser.add_argument(
        "--onnx-encoder-filename",
        required=True,
        type=str,
        help="Path to the ONNX encoder model",
    )

    parser.add_argument(
        "--onnx-decoder-filename",
        required=True,
        type=str,
        help="Path to the ONNX decoder model",
    )

    parser.add_argument(
        "--onnx-joiner-filename",
        required=True,
        type=str,
        help="Path to the ONNX joiner model",
    )

    return parser

def test_encoder(
    torch_encoder_model: torch.jit.ScriptModule,
    torch_encoder_proj_model: torch.jit.ScriptModule,
    onnx_model: OnnxModel,
):
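    # Feed several chunks of random features through both encoders, carrying
    # the streaming states forward, and require the projected encoder outputs
    # to agree within an absolute tolerance of 1e-4.  N is a random batch
    # size, T is the chunk (segment) length expected by the ONNX encoder,
    # and C is the feature dimension.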
    N = torch.randint(1, 100, size=(1,)).item()
    T = onnx_model.segment
    C = 80

    x_lens = torch.tensor([T] * N)
    torch_states = torch_encoder_model.get_init_states(N)
    onnx_model.init_encoder_states(N)

    for i in range(5):
        logging.info(f"test_encoder: iter {i}")
        x = torch.rand(N, T, C)
        torch_encoder_out, _, torch_states = torch_encoder_model(
            x, x_lens, torch_states
        )
        torch_encoder_out = torch_encoder_proj_model(torch_encoder_out)

        onnx_encoder_out = onnx_model.run_encoder(x)

        assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), (
            (torch_encoder_out - onnx_encoder_out).abs().max()
        )

def test_decoder(
    torch_decoder_model: torch.jit.ScriptModule,
    torch_decoder_proj_model: torch.jit.ScriptModule,
    onnx_model: OnnxModel,
):
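    # Feed random token IDs of shape (N, context_size) to both decoders and
    # require the projected decoder outputs to agree within 1e-4.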
    context_size = onnx_model.context_size
    vocab_size = onnx_model.vocab_size

    for i in range(10):
        N = torch.randint(1, 100, size=(1,)).item()
        logging.info(f"test_decoder: iter {i}, N={N}")
        x = torch.randint(
            low=1,
            high=vocab_size,
            size=(N, context_size),
            dtype=torch.int64,
        )
        torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False]))
        torch_decoder_out = torch_decoder_proj_model(torch_decoder_out)
        torch_decoder_out = torch_decoder_out.squeeze(1)

        onnx_decoder_out = onnx_model.run_decoder(x)

        assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
            (torch_decoder_out - onnx_decoder_out).abs().max()
        )

def test_joiner(
    torch_joiner_model: torch.jit.ScriptModule,
    onnx_model: OnnxModel,
):
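    # The ONNX joiner expects inputs that have already been passed through
    # encoder_proj and decoder_proj (those projections are folded into the
    # ONNX encoder and decoder at export time), so project the random inputs
    # before calling it.  The torchscript joiner applies the projections
    # itself, so it receives the unprojected tensors.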
    encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1]
    decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1]

    for i in range(10):
        N = torch.randint(1, 100, size=(1,)).item()
        logging.info(f"test_joiner: iter {i}, N={N}")
        encoder_out = torch.rand(N, encoder_dim)
        decoder_out = torch.rand(N, decoder_dim)

        projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out)
        projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out)

        torch_joiner_out = torch_joiner_model(encoder_out, decoder_out)
        onnx_joiner_out = onnx_model.run_joiner(
            projected_encoder_out, projected_decoder_out
        )

        assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
            (torch_joiner_out - onnx_joiner_out).abs().max()
        )

@torch.no_grad()
def main():
    args = get_parser().parse_args()
    logging.info(vars(args))

    torch_encoder_model = torch.jit.load(args.jit_encoder_filename)
    torch_decoder_model = torch.jit.load(args.jit_decoder_filename)
    torch_joiner_model = torch.jit.load(args.jit_joiner_filename)

    onnx_model = OnnxModel(
        encoder_model_filename=args.onnx_encoder_filename,
        decoder_model_filename=args.onnx_decoder_filename,
        joiner_model_filename=args.onnx_joiner_filename,
    )

    logging.info("Test encoder")
    # When exporting the model to onnx, we have already put the encoder_proj
    # inside the encoder.
    test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model)

    logging.info("Test decoder")
    # When exporting the model to onnx, we have already put the decoder_proj
    # inside the decoder.
    test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model)

    logging.info("Test joiner")
    test_joiner(torch_joiner_model, onnx_model)

    logging.info("Finished checking ONNX models")

torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# See https://github.com/pytorch/pytorch/issues/38342
# and https://github.com/pytorch/pytorch/issues/33354
#
# If we don't do this, the delay increases whenever there is
# a new request that changes the actual batch size.
# If you use `py-spy dump --pid <server-pid> --native`, you will
# see a lot of time is spent in re-compiling the torch script model.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)

if __name__ == "__main__":
    torch.manual_seed(20230207)
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()