mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Update CI to test ONNX models.
This commit is contained in:
parent
c70df281c6
commit
71ea196370
@ -22,8 +22,35 @@ ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
|
||||
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
log "Test exporting to ONNX format"
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--onnx 1
|
||||
|
||||
log "Export to torchscript model"
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit 1
|
||||
|
||||
ls -lh $repo/exp/*.onnx
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
./pruned_transducer_stateless3/onnx_check.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner.onnx
|
||||
|
||||
for sym in 1 2 3; do
|
||||
log "Greedy search with --max-sym-per-frame $sym"
|
||||
|
||||
|
@ -35,7 +35,7 @@ on:
|
||||
|
||||
jobs:
|
||||
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
|
@ -45,8 +45,8 @@ are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
--avg 10 \
|
||||
--onnx 1
|
||||
|
||||
It will generate the following files in the given `exp_dir`.
|
||||
See `onnx_check.py` to see how to use it.
|
||||
It will generate the following three files in the given `exp_dir`.
|
||||
Check `onnx_check.py` for how to use them.
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
@ -82,7 +82,6 @@ you can do:
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import onnx
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
@ -219,7 +218,7 @@ def export_encoder_model_onnx(
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
# encoder_model = torch.jit.script(model.encoder)
|
||||
# encoder_model = torch.jit.script(encoder_model)
|
||||
# It throws the following error for the above statement
|
||||
#
|
||||
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
||||
@ -257,7 +256,7 @@ def export_decoder_model_onnx(
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, 2)
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
@ -316,7 +315,7 @@ def export_joiner_model_onnx(
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
project_input = True
|
||||
# Note: We use torch.jit.trace() here
|
||||
# Note: It uses torch.jit.trace() internally
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(encoder_out, decoder_out, project_input),
|
||||
|
@ -25,12 +25,10 @@ import argparse
|
||||
import logging
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -188,6 +186,7 @@ def main():
|
||||
sess_options=options,
|
||||
)
|
||||
test_joiner(model, joiner_session)
|
||||
logging.info("Finished checking ONNX models")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -20,3 +20,6 @@ sentencepiece==0.1.96
|
||||
tensorboard==2.8.0
|
||||
typeguard==2.13.3
|
||||
multi_quantization
|
||||
|
||||
onnx
|
||||
onnxruntime
|
||||
|
@ -4,3 +4,5 @@ sentencepiece>=0.1.96
|
||||
tensorboard
|
||||
typeguard
|
||||
multi_quantization
|
||||
onnx
|
||||
onnxruntime
|
||||
|
Loading…
x
Reference in New Issue
Block a user