Update CI to test ONNX models.
This commit is contained in:
parent
c70df281c6
commit
71ea196370
@ -22,8 +22,35 @@ ls -lh $repo/test_wavs/*.wav
|
|||||||
|
|
||||||
pushd $repo/exp
|
pushd $repo/exp
|
||||||
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
|
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
|
||||||
|
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
|
||||||
popd
|
popd
|
||||||
|
|
||||||
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--onnx 1
|
||||||
|
|
||||||
|
log "Export to torchscript model"
|
||||||
|
./pruned_transducer_stateless3/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
ls -lh $repo/exp/*.onnx
|
||||||
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
|
./pruned_transducer_stateless3/onnx_check.py \
|
||||||
|
--jit-filename $repo/exp/cpu_jit.pt \
|
||||||
|
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
||||||
|
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||||
|
--onnx-joiner-filename $repo/exp/joiner.onnx
|
||||||
|
|
||||||
for sym in 1 2 3; do
|
for sym in 1 2 3; do
|
||||||
log "Greedy search with --max-sym-per-frame $sym"
|
log "Greedy search with --max-sym-per-frame $sym"
|
||||||
|
|
||||||
|
|||||||
@ -35,7 +35,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
||||||
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
|||||||
@ -45,8 +45,8 @@ are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
|||||||
--avg 10 \
|
--avg 10 \
|
||||||
--onnx 1
|
--onnx 1
|
||||||
|
|
||||||
It will generate the following files in the given `exp_dir`.
|
It will generate the following three files in the given `exp_dir`.
|
||||||
See `onnx_check.py` to see how to use it.
|
Check `onnx_check.py` for how to use them.
|
||||||
|
|
||||||
- encoder.onnx
|
- encoder.onnx
|
||||||
- decoder.onnx
|
- decoder.onnx
|
||||||
@ -82,7 +82,6 @@ you can do:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import onnx
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
@ -219,7 +218,7 @@ def export_encoder_model_onnx(
|
|||||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
# encoder_model = torch.jit.script(model.encoder)
|
# encoder_model = torch.jit.script(encoder_model)
|
||||||
# It throws the following error for the above statement
|
# It throws the following error for the above statement
|
||||||
#
|
#
|
||||||
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
||||||
@ -257,7 +256,7 @@ def export_decoder_model_onnx(
|
|||||||
|
|
||||||
The exported model has one input:
|
The exported model has one input:
|
||||||
|
|
||||||
- y: a torch.int64 tensor of shape (N, 2)
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
and has one output:
|
and has one output:
|
||||||
|
|
||||||
@ -316,7 +315,7 @@ def export_joiner_model_onnx(
|
|||||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
project_input = True
|
project_input = True
|
||||||
# Note: We use torch.jit.trace() here
|
# Note: It uses torch.jit.trace() internally
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
joiner_model,
|
joiner_model,
|
||||||
(encoder_out, decoder_out, project_input),
|
(encoder_out, decoder_out, project_input),
|
||||||
|
|||||||
@ -25,12 +25,10 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
|
||||||
ort.set_default_logger_severity(3)
|
ort.set_default_logger_severity(3)
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -188,6 +186,7 @@ def main():
|
|||||||
sess_options=options,
|
sess_options=options,
|
||||||
)
|
)
|
||||||
test_joiner(model, joiner_session)
|
test_joiner(model, joiner_session)
|
||||||
|
logging.info("Finished checking ONNX models")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -20,3 +20,6 @@ sentencepiece==0.1.96
|
|||||||
tensorboard==2.8.0
|
tensorboard==2.8.0
|
||||||
typeguard==2.13.3
|
typeguard==2.13.3
|
||||||
multi_quantization
|
multi_quantization
|
||||||
|
|
||||||
|
onnx
|
||||||
|
onnxruntime
|
||||||
|
|||||||
@ -4,3 +4,5 @@ sentencepiece>=0.1.96
|
|||||||
tensorboard
|
tensorboard
|
||||||
typeguard
|
typeguard
|
||||||
multi_quantization
|
multi_quantization
|
||||||
|
onnx
|
||||||
|
onnxruntime
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user