Update CI to test ONNX models.

Fangjun Kuang 2022-07-30 11:50:39 +08:00
parent c70df281c6
commit 71ea196370
6 changed files with 40 additions and 10 deletions

File: CI test script (shell)

@@ -22,8 +22,35 @@ ls -lh $repo/test_wavs/*.wav
 pushd $repo/exp
 ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
 ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
 popd

+log "Test exporting to ONNX format"
+
+./pruned_transducer_stateless3/export.py \
+  --exp-dir $repo/exp \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --onnx 1
+
+log "Export to torchscript model"
+
+./pruned_transducer_stateless3/export.py \
+  --exp-dir $repo/exp \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --jit 1
+
+ls -lh $repo/exp/*.onnx
+ls -lh $repo/exp/*.pt
+
+./pruned_transducer_stateless3/onnx_check.py \
+  --jit-filename $repo/exp/cpu_jit.pt \
+  --onnx-encoder-filename $repo/exp/encoder.onnx \
+  --onnx-decoder-filename $repo/exp/decoder.onnx \
+  --onnx-joiner-filename $repo/exp/joiner.onnx
+
 for sym in 1 2 3; do
   log "Greedy search with --max-sym-per-frame $sym"

File: GitHub Actions workflow

@@ -35,7 +35,7 @@ on:
 jobs:
   run_librispeech_pruned_transducer_stateless3_2022_05_13:
-    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:

File: pruned_transducer_stateless3/export.py

@@ -45,8 +45,8 @@ are on CPU. You can use `to("cuda")` to move them to a CUDA device.
   --avg 10 \
   --onnx 1

-It will generate the following files in the given `exp_dir`.
-See `onnx_check.py` to see how to use it.
+It will generate the following three files in the given `exp_dir`.
+Check `onnx_check.py` for how to use them.

 - encoder.onnx
 - decoder.onnx
@@ -82,7 +82,6 @@ you can do:
 import argparse
 import logging

-import onnx
 from pathlib import Path

 import sentencepiece as spm
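
Although the unused `import onnx` is dropped from export.py, the onnx package (added to the requirements below) is still handy for sanity-checking the generated files. A sketch, assuming the three file names from the docstring above:

import onnx

for name in ["encoder.onnx", "decoder.onnx", "joiner.onnx"]:
    model = onnx.load(name)
    onnx.checker.check_model(model)  # raises if the graph is malformed
    print(name, "opset:", model.opset_import[0].version)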
@@ -219,7 +218,7 @@ def export_encoder_model_onnx(
     x = torch.zeros(1, 100, 80, dtype=torch.float32)
     x_lens = torch.tensor([100], dtype=torch.int64)

-    # encoder_model = torch.jit.script(model.encoder)
+    # encoder_model = torch.jit.script(encoder_model)
     # It throws the following error for the above statement
     #
     # RuntimeError: Exporting the operator __is_ to ONNX opset version
@@ -257,7 +256,7 @@ def export_decoder_model_onnx(
     The exported model has one input:

-        - y: a torch.int64 tensor of shape (N, 2)
+        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)

     and has one output:
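
Because the docstring now ties the input shape to decoder_model.context_size rather than a hard-coded 2, a caller can recover context_size from the exported model itself. A sketch, assuming the input is named "y" as documented above:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("decoder.onnx")
inp = session.get_inputs()[0]
print(inp.name, inp.shape)  # e.g. "y", ["N", 2]

context_size = inp.shape[1]  # fixed at export time
y = np.zeros((1, context_size), dtype=np.int64)  # e.g. blank-padded history
outputs = session.run(None, {inp.name: y})
print(outputs[0].shape)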
@@ -316,7 +315,7 @@ def export_joiner_model_onnx(
     decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
     project_input = True

-    # Note: We use torch.jit.trace() here
+    # Note: It uses torch.jit.trace() internally
     torch.onnx.export(
         joiner_model,
         (encoder_out, decoder_out, project_input),
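
The updated note refers to torch.onnx.export() tracing the module with the example inputs above. A self-contained sketch of that pattern with a stand-in joiner (not the icefall model; names and dimensions are illustrative):

import torch
import torch.nn as nn


class TinyJoiner(nn.Module):
    # Stand-in for the real joiner: combines encoder and decoder outputs.
    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
        return torch.tanh(encoder_out + decoder_out)


joiner = TinyJoiner()
encoder_out = torch.rand(1, 512)
decoder_out = torch.rand(1, 512)

torch.onnx.export(
    joiner,
    (encoder_out, decoder_out),  # example inputs; export traces with these
    "tiny_joiner.onnx",
    input_names=["encoder_out", "decoder_out"],
    output_names=["logit"],
    dynamic_axes={
        "encoder_out": {0: "N"},
        "decoder_out": {0: "N"},
        "logit": {0: "N"},
    },
    opset_version=11,
)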

File: pruned_transducer_stateless3/onnx_check.py

@@ -25,12 +25,10 @@ import argparse
 import logging

 import onnxruntime as ort
+import torch

 ort.set_default_logger_severity(3)

-import numpy as np
-import torch


 def get_parser():
     parser = argparse.ArgumentParser(
@@ -188,6 +186,7 @@ def main():
         sess_options=options,
     )
     test_joiner(model, joiner_session)
+    logging.info("Finished checking ONNX models")


 if __name__ == "__main__":
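
The test_* calls above compare torchscript and ONNX outputs on the same inputs. A sketch of that pattern (the `joiner` attribute, its signature, and the tensor names are assumptions, not the script's exact API):

import onnxruntime as ort
import torch

jit_model = torch.jit.load("cpu_jit.pt")
session = ort.InferenceSession("joiner.onnx")

encoder_out = torch.rand(1, 512)
decoder_out = torch.rand(1, 512)

# Run the same inputs through both backends.
onnx_out = session.run(
    None,
    {"encoder_out": encoder_out.numpy(), "decoder_out": decoder_out.numpy()},
)[0]
jit_out = jit_model.joiner(encoder_out, decoder_out)  # assumed submodule

# Small numerical drift between the two backends is expected.
assert torch.allclose(jit_out, torch.from_numpy(onnx_out), atol=1e-5)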

File: CI requirements file (pinned versions)

@@ -20,3 +20,6 @@ sentencepiece==0.1.96
 tensorboard==2.8.0
 typeguard==2.13.3
 multi_quantization
+onnx
+onnxruntime

File: main requirements file

@@ -4,3 +4,5 @@ sentencepiece>=0.1.96
 tensorboard
 typeguard
 multi_quantization
+onnx
+onnxruntime