Remove all-in-one for onnx export (#614)

* Remove all-in-one for onnx export

* Exit on error for CI
Fangjun Kuang 2022-10-12 10:34:06 +08:00 committed by GitHub
parent f3db4ea871
commit 1c07d2fb37
29 changed files with 636 additions and 465 deletions


@ -4,6 +4,8 @@
# The computed features are saved to ~/tmp/fbank-libri and are
# cached for later runs
set -e
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
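(`set -e` makes bash exit as soon as any command fails, so in CI a failed download, install, or export now fails the job immediately instead of letting later steps run against missing artifacts. The same line is added to each of the scripts below.)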


@ -6,6 +6,8 @@
# You will find directories `~/tmp/giga-dev-dataset-fbank` after running
# this script.
set -e
mkdir -p ~/tmp
cd ~/tmp


@ -7,6 +7,8 @@
# You will find directories ~/tmp/download/LibriSpeech after running
# this script.
set -e
mkdir ~/tmp/download
cd egs/librispeech/ASR
ln -s ~/tmp/download .


@ -3,6 +3,8 @@
# This script installs kaldifeat into the directory ~/tmp/kaldifeat
# which is cached by GitHub actions for later runs.
set -e
mkdir -p ~/tmp
cd ~/tmp
git clone https://github.com/csukuangfj/kaldifeat


@ -4,6 +4,8 @@
# to egs/librispeech/ASR/download/LibriSpeech and generates manifest
# files in egs/librispeech/ASR/data/manifests
set -e
cd egs/librispeech/ASR
[ ! -e download ] && ln -s ~/tmp/download .
mkdir -p data/manifests


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,4 +1,6 @@
#!/usr/bin/env bash
#
set -e
log() {
# This function is from espnet


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
@ -62,15 +64,13 @@ log "Decode with ONNX models"
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
./pruned_transducer_stateless3/onnx_check_all_in_one.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
./pruned_transducer_stateless3/onnx_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
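The test script above now decodes with five separate ONNX files rather than one all-in-one model. As a quick sanity check (a sketch, not part of this commit; `exp/` stands in for `$repo/exp`), onnxruntime can confirm that each exported file loads and report its tensor names:

import onnxruntime as ort

for name in [
    "encoder.onnx",
    "decoder.onnx",
    "joiner.onnx",
    "joiner_encoder_proj.onnx",
    "joiner_decoder_proj.onnx",
]:
    sess = ort.InferenceSession(f"exp/{name}")
    print(name)
    print("  inputs: ", [(i.name, i.shape) for i in sess.get_inputs()])
    print("  outputs:", [(o.name, o.shape) for o in sess.get_outputs()])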


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -1,5 +1,7 @@
#!/usr/bin/env bash
set -e
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}


@ -476,8 +476,8 @@ class ConformerEncoderLayer(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
src_mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
"""
@ -486,8 +486,8 @@ class ConformerEncoderLayer(nn.Module):
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
src_mask: the mask for the src sequence (optional).
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
Shape:
@ -663,8 +663,8 @@ class ConformerEncoder(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@ -672,8 +672,8 @@ class ConformerEncoder(nn.Module):
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
mask: the mask for the src sequence (optional).
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
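The argument swap in `ConformerEncoderLayer.forward` and `ConformerEncoder.forward` above is not cosmetic: `torch.jit.trace` and `torch.onnx.export` bind example inputs positionally, and the new test file traces these modules as `(x, pos_emb, src_key_padding_mask)`. A minimal sketch of why the order matters (`Demo` is a hypothetical module, not code from this repository):

import torch
import torch.nn as nn

class Demo(nn.Module):
    # Same parameter order as the updated forward(): the padding mask, which
    # is what the export path actually passes, comes before src_mask.
    def forward(self, src, pos_emb, src_key_padding_mask=None, src_mask=None):
        # Zero out padded frames; a stand-in for real attention masking.
        out = src.masked_fill(src_key_padding_mask.t().unsqueeze(-1), 0.0)
        return out + 0.0 * pos_emb.sum()

x = torch.rand(12, 3, 8)                    # (T, N, C)
pos_emb = torch.rand(1, 2 * 12 - 1, 8)      # relative positional embedding
pad = torch.zeros(3, 12, dtype=torch.bool)  # (N, T), True at padded frames

# trace binds the example inputs positionally, so the third tensor lands on
# src_key_padding_mask only with the new parameter order.
traced = torch.jit.trace(Demo(), (x, pos_emb, pad))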


@ -62,7 +62,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
--avg 10 \
--onnx 1
It will generate the following six files in the given `exp_dir`.
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
@ -70,8 +70,8 @@ Check `onnx_check.py` for how to use them.
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- all_in_one.onnx
Please see ./onnx_pretrained.py for usage of the generated files
(4) Export `model.state_dict()`
@ -118,8 +118,6 @@ import argparse
import logging
from pathlib import Path
import onnx_graphsurgeon as gs
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
@ -217,16 +215,15 @@ def get_parser():
type=str2bool,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. Three files will be generated:
to onnx format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- all_in_one.onnx
Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
@ -483,134 +480,99 @@ def export_joiner_model_onnx(
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported model has two inputs:
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and has one output:
- joiner_out: a tensor of shape (N, vocab_size)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(
".onnx", "_encoder_proj.onnx"
)
decoder_proj_filename = str(joiner_filename).replace(
".onnx", "_decoder_proj.onnx"
)
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, 1, 1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, 1, 1, decoder_out_dim, dtype=torch.float32)
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
project_input = True
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(encoder_out, decoder_out, project_input),
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out", "decoder_out", "project_input"],
input_names=[
"projected_encoder_out",
"projected_decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
torch.onnx.export(
joiner_model.encoder_proj,
(encoder_out.squeeze(0).squeeze(0)),
str(joiner_filename).replace(".onnx", "_encoder_proj.onnx"),
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["encoder_proj"],
dynamic_axes={
"encoder_out": {0: "N"},
"encoder_proj": {0: "N"},
},
)
torch.onnx.export(
joiner_model.decoder_proj,
(decoder_out.squeeze(0).squeeze(0)),
str(joiner_filename).replace(".onnx", "_decoder_proj.onnx"),
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["decoder_proj"],
dynamic_axes={
"decoder_out": {0: "N"},
"decoder_proj": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
def add_variables(
model: nn.Module, combined_model: onnx.ModelProto
) -> onnx.ModelProto:
graph = gs.import_onnx(combined_model)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
node = gs.Node(
op="Identity",
name="constants_lm",
attrs={
"blank_id": blank_id,
"unk_id": unk_id,
"context_size": context_size,
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
inputs=[],
outputs=[],
)
graph.nodes.append(node)
logging.info(f"Saved to {encoder_proj_filename}")
graph = gs.export_onnx(graph)
return graph
def export_all_in_one_onnx(
model: nn.Module,
encoder_filename: str,
decoder_filename: str,
joiner_filename: str,
all_in_one_filename: str,
):
encoder_onnx = onnx.load(encoder_filename)
decoder_onnx = onnx.load(decoder_filename)
joiner_onnx = onnx.load(joiner_filename)
joiner_encoder_proj_onnx = onnx.load(
str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
joiner_decoder_proj_onnx = onnx.load(
str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
)
encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
joiner_encoder_proj_onnx = onnx.compose.add_prefix(
joiner_encoder_proj_onnx, prefix="joiner_encoder_proj/"
)
joiner_decoder_proj_onnx = onnx.compose.add_prefix(
joiner_decoder_proj_onnx, prefix="joiner_decoder_proj/"
)
combined_model = onnx.compose.merge_models(
encoder_onnx, decoder_onnx, io_map={}
)
combined_model = onnx.compose.merge_models(
combined_model, joiner_onnx, io_map={}
)
combined_model = onnx.compose.merge_models(
combined_model, joiner_encoder_proj_onnx, io_map={}
)
combined_model = onnx.compose.merge_models(
combined_model, joiner_decoder_proj_onnx, io_map={}
)
combined_model = add_variables(model, combined_model)
onnx.save(combined_model, all_in_one_filename)
logging.info(f"Saved to {all_in_one_filename}")
logging.info(f"Saved to {decoder_proj_filename}")
@torch.no_grad()
@ -704,15 +666,6 @@ def main():
joiner_filename,
opset_version=opset_version,
)
all_in_one_filename = params.exp_dir / "all_in_one.onnx"
export_all_in_one_onnx(
model,
encoder_filename,
decoder_filename,
joiner_filename,
all_in_one_filename,
)
elif params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.script()")
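With `all_in_one.onnx` gone, callers compose the exported pieces themselves: run the two projection models first, then hand the projected tensors to the joiner. A sketch of the new calling convention (the 512-dimensional inputs follow this recipe's defaults and are an assumption here; file names match the exports above):

import numpy as np
import onnxruntime as ort

enc_proj = ort.InferenceSession("exp/joiner_encoder_proj.onnx")
dec_proj = ort.InferenceSession("exp/joiner_decoder_proj.onnx")
joiner = ort.InferenceSession("exp/joiner.onnx")

encoder_out = np.random.rand(1, 512).astype(np.float32)  # (N, encoder_out_dim)
decoder_out = np.random.rand(1, 512).astype(np.float32)  # (N, decoder_out_dim)

# Step 1: project into joiner_dim.
projected_encoder_out = enc_proj.run(
    [enc_proj.get_outputs()[0].name],
    {enc_proj.get_inputs()[0].name: encoder_out},
)[0]
projected_decoder_out = dec_proj.run(
    [dec_proj.get_outputs()[0].name],
    {dec_proj.get_inputs()[0].name: decoder_out},
)[0]

# Step 2: the joiner consumes the already-projected (N, joiner_dim) tensors.
logit = joiner.run(
    [joiner.get_outputs()[0].name],
    {
        joiner.get_inputs()[0].name: projected_encoder_out,
        joiner.get_inputs()[1].name: projected_decoder_out,
    },
)[0]
print(logit.shape)  # (N, vocab_size)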


@ -84,11 +84,13 @@ def test_encoder(
model: torch.jit.ScriptModule,
encoder_session: ort.InferenceSession,
):
encoder_inputs = encoder_session.get_inputs()
assert encoder_inputs[0].name == "x"
assert encoder_inputs[1].name == "x_lens"
assert encoder_inputs[0].shape == ["N", "T", 80]
assert encoder_inputs[1].shape == ["N"]
inputs = encoder_session.get_inputs()
outputs = encoder_session.get_outputs()
input_names = [n.name for n in inputs]
output_names = [n.name for n in outputs]
assert inputs[0].shape == ["N", "T", 80]
assert inputs[1].shape == ["N"]
for N in [1, 5]:
for T in [12, 25]:
@ -98,11 +100,11 @@ def test_encoder(
x_lens[0] = T
encoder_inputs = {
"x": x.numpy(),
"x_lens": x_lens.numpy(),
input_names[0]: x.numpy(),
input_names[1]: x_lens.numpy(),
}
encoder_out, encoder_out_lens = encoder_session.run(
["encoder_out", "encoder_out_lens"],
output_names,
encoder_inputs,
)
@ -110,7 +112,9 @@ def test_encoder(
encoder_out = torch.from_numpy(encoder_out)
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
(encoder_out - torch_encoder_out).abs().max()
(encoder_out - torch_encoder_out).abs().max(),
encoder_out.shape,
torch_encoder_out.shape,
)
@ -118,15 +122,18 @@ def test_decoder(
model: torch.jit.ScriptModule,
decoder_session: ort.InferenceSession,
):
decoder_inputs = decoder_session.get_inputs()
assert decoder_inputs[0].name == "y"
assert decoder_inputs[0].shape == ["N", 2]
inputs = decoder_session.get_inputs()
outputs = decoder_session.get_outputs()
input_names = [n.name for n in inputs]
output_names = [n.name for n in outputs]
assert inputs[0].shape == ["N", 2]
for N in [1, 5, 10]:
y = torch.randint(low=1, high=500, size=(10, 2))
decoder_inputs = {"y": y.numpy()}
decoder_inputs = {input_names[0]: y.numpy()}
decoder_out = decoder_session.run(
["decoder_out"],
output_names,
decoder_inputs,
)[0]
decoder_out = torch.from_numpy(decoder_out)
@ -144,51 +151,62 @@ def test_joiner(
joiner_decoder_proj_session: ort.InferenceSession,
):
joiner_inputs = joiner_session.get_inputs()
assert joiner_inputs[0].name == "encoder_out"
assert joiner_inputs[0].shape == ["N", 1, 1, 512]
joiner_outputs = joiner_session.get_outputs()
joiner_input_names = [n.name for n in joiner_inputs]
joiner_output_names = [n.name for n in joiner_outputs]
assert joiner_inputs[1].name == "decoder_out"
assert joiner_inputs[1].shape == ["N", 1, 1, 512]
assert joiner_inputs[0].shape == ["N", 512]
assert joiner_inputs[1].shape == ["N", 512]
joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
assert joiner_encoder_proj_inputs[0].name == "encoder_out"
encoder_proj_input_name = joiner_encoder_proj_inputs[0].name
assert joiner_encoder_proj_inputs[0].shape == ["N", 512]
joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs()
encoder_proj_output_name = joiner_encoder_proj_outputs[0].name
joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
assert joiner_decoder_proj_inputs[0].name == "decoder_out"
decoder_proj_input_name = joiner_decoder_proj_inputs[0].name
assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs()
decoder_proj_output_name = joiner_decoder_proj_outputs[0].name
for N in [1, 5, 10]:
encoder_out = torch.rand(N, 1, 1, 512)
decoder_out = torch.rand(N, 1, 1, 512)
encoder_out = torch.rand(N, 512)
decoder_out = torch.rand(N, 512)
projected_encoder_out = torch.rand(N, 512)
projected_decoder_out = torch.rand(N, 512)
joiner_inputs = {
"encoder_out": encoder_out.numpy(),
"decoder_out": decoder_out.numpy(),
joiner_input_names[0]: projected_encoder_out.numpy(),
joiner_input_names[1]: projected_decoder_out.numpy(),
}
joiner_out = joiner_session.run(["logit"], joiner_inputs)[0]
joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0]
joiner_out = torch.from_numpy(joiner_out)
torch_joiner_out = model.joiner(
encoder_out,
decoder_out,
project_input=True,
projected_encoder_out,
projected_decoder_out,
project_input=False,
)
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
(joiner_out - torch_joiner_out).abs().max()
)
# Now test encoder_proj
joiner_encoder_proj_inputs = {
"encoder_out": encoder_out.squeeze(1).squeeze(1).numpy()
encoder_proj_input_name: encoder_out.numpy()
}
joiner_encoder_proj_out = joiner_encoder_proj_session.run(
["encoder_proj"], joiner_encoder_proj_inputs
[encoder_proj_output_name], joiner_encoder_proj_inputs
)[0]
joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
torch_joiner_encoder_proj_out = model.joiner.encoder_proj(
encoder_out.squeeze(1).squeeze(1)
)
torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
assert torch.allclose(
joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
), (
@ -197,17 +215,16 @@ def test_joiner(
.max()
)
# Now test decoder_proj
joiner_decoder_proj_inputs = {
"decoder_out": decoder_out.squeeze(1).squeeze(1).numpy()
decoder_proj_input_name: decoder_out.numpy()
}
joiner_decoder_proj_out = joiner_decoder_proj_session.run(
["decoder_proj"], joiner_decoder_proj_inputs
[decoder_proj_output_name], joiner_decoder_proj_inputs
)[0]
joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
torch_joiner_decoder_proj_out = model.joiner.decoder_proj(
decoder_out.squeeze(1).squeeze(1)
)
torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
assert torch.allclose(
joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
), (
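Note the pattern used throughout the updated checks: instead of asserting hard-coded tensor names, each session is asked for its own input and output names. A small helper capturing that pattern (a sketch, not code from this file):

import onnxruntime as ort

def run_with_discovered_names(session: ort.InferenceSession, *arrays):
    # Feed positional numpy arrays without hard-coding graph input names,
    # and return every output in the session's declared order.
    feed = {node.name: array for node, array in zip(session.get_inputs(), arrays)}
    return session.run([node.name for node in session.get_outputs()], feed)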


@ -1,284 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Yunus Emre Ozkose)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks that the exported onnx models produce the same output
as the given torchscript model for the same input.
"""
import argparse
import logging
import os
import onnx
import onnx_graphsurgeon as gs
import onnxruntime
import onnxruntime as ort
import torch
ort.set_default_logger_severity(3)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit-filename",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx-all-in-one-filename",
required=True,
type=str,
help="Path to the onnx all in one model",
)
return parser
def test_encoder(
model: torch.jit.ScriptModule,
encoder_session: ort.InferenceSession,
):
encoder_inputs = encoder_session.get_inputs()
assert encoder_inputs[0].shape == ["N", "T", 80]
assert encoder_inputs[1].shape == ["N"]
encoder_input_names = [i.name for i in encoder_inputs]
encoder_output_names = [i.name for i in encoder_session.get_outputs()]
for N in [1, 5]:
for T in [12, 25]:
print("N, T", N, T)
x = torch.rand(N, T, 80, dtype=torch.float32)
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
x_lens[0] = T
encoder_inputs = {
encoder_input_names[0]: x.numpy(),
encoder_input_names[1]: x_lens.numpy(),
}
encoder_out, encoder_out_lens = encoder_session.run(
[encoder_output_names[1], encoder_output_names[0]],
encoder_inputs,
)
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
encoder_out = torch.from_numpy(encoder_out)
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
(encoder_out - torch_encoder_out).abs().max()
)
def test_decoder(
model: torch.jit.ScriptModule,
decoder_session: ort.InferenceSession,
):
decoder_inputs = decoder_session.get_inputs()
assert decoder_inputs[0].shape == ["N", 2]
decoder_input_names = [i.name for i in decoder_inputs]
decoder_output_names = [i.name for i in decoder_session.get_outputs()]
for N in [1, 5, 10]:
y = torch.randint(low=1, high=500, size=(10, 2))
decoder_inputs = {decoder_input_names[0]: y.numpy()}
decoder_out = decoder_session.run(
[decoder_output_names[0]],
decoder_inputs,
)[0]
decoder_out = torch.from_numpy(decoder_out)
torch_decoder_out = model.decoder(y, need_pad=False)
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
(decoder_out - torch_decoder_out).abs().max()
)
def test_joiner(
model: torch.jit.ScriptModule,
joiner_session: ort.InferenceSession,
):
joiner_inputs = joiner_session.get_inputs()
assert joiner_inputs[0].shape == ["N", 512]
assert joiner_inputs[1].shape == ["N", 512]
joiner_input_names = [i.name for i in joiner_inputs]
joiner_output_names = [i.name for i in joiner_session.get_outputs()]
for N in [1, 5, 10]:
encoder_out = torch.rand(N, 512)
decoder_out = torch.rand(N, 512)
joiner_inputs = {
joiner_input_names[0]: encoder_out.numpy(),
joiner_input_names[1]: decoder_out.numpy(),
}
joiner_out = joiner_session.run(
[joiner_output_names[0]], joiner_inputs
)[0]
joiner_out = torch.from_numpy(joiner_out)
torch_joiner_out = model.joiner(
encoder_out,
decoder_out,
project_input=True,
)
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
(joiner_out - torch_joiner_out).abs().max()
)
def extract_sub_model(
onnx_graph: onnx.ModelProto,
input_op_names: list,
output_op_names: list,
non_verbose=False,
):
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
graph = gs.import_onnx(onnx_graph)
graph.cleanup().toposort()
# Extraction of input OP and output OP
graph_node_inputs = [
graph_nodes
for graph_nodes in graph.nodes
for graph_nodes_input in graph_nodes.inputs
if graph_nodes_input.name in input_op_names
]
graph_node_outputs = [
graph_nodes
for graph_nodes in graph.nodes
for graph_nodes_output in graph_nodes.outputs
if graph_nodes_output.name in output_op_names
]
# Init graph INPUT/OUTPUT
graph.inputs.clear()
graph.outputs.clear()
# Update graph INPUT/OUTPUT
graph.inputs = [
graph_node_input
for graph_node in graph_node_inputs
for graph_node_input in graph_node.inputs
if graph_node_input.shape
]
graph.outputs = [
graph_node_output
for graph_node in graph_node_outputs
for graph_node_output in graph_node.outputs
]
# Cleanup
graph.cleanup().toposort()
# Shape Estimation
extracted_graph = None
try:
extracted_graph = onnx.shape_inference.infer_shapes(
gs.export_onnx(graph)
)
except Exception:
extracted_graph = gs.export_onnx(graph)
if not non_verbose:
print(
"WARNING: "
+ "The input shape of the next OP does not match the output shape. "
+ "Be sure to open the .onnx file to verify the resulting shapes."
)
return extracted_graph
def extract_encoder(onnx_model: onnx.ModelProto):
encoder_ = extract_sub_model(
onnx_model,
["encoder/x", "encoder/x_lens"],
["encoder/encoder_out", "encoder/encoder_out_lens"],
False,
)
onnx.save(encoder_, "tmp_encoder.onnx")
onnx.checker.check_model(encoder_)
sess = onnxruntime.InferenceSession("tmp_encoder.onnx")
os.remove("tmp_encoder.onnx")
return sess
def extract_decoder(onnx_model: onnx.ModelProto):
decoder_ = extract_sub_model(
onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
)
onnx.save(decoder_, "tmp_decoder.onnx")
onnx.checker.check_model(decoder_)
sess = onnxruntime.InferenceSession("tmp_decoder.onnx")
os.remove("tmp_decoder.onnx")
return sess
def extract_joiner(onnx_model: onnx.ModelProto):
joiner_ = extract_sub_model(
onnx_model,
["joiner/encoder_out", "joiner/decoder_out"],
["joiner/logit"],
False,
)
onnx.save(joiner_, "tmp_joiner.onnx")
onnx.checker.check_model(joiner_)
sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
os.remove("tmp_joiner.onnx")
return sess
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
model = torch.jit.load(args.jit_filename)
onnx_model = onnx.load(args.onnx_all_in_one_filename)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
logging.info("Test encoder")
encoder_session = extract_encoder(onnx_model)
test_encoder(model, encoder_session)
logging.info("Test decoder")
decoder_session = extract_decoder(onnx_model)
test_decoder(model, decoder_session)
logging.info("Test joiner")
joiner_session = extract_joiner(onnx_model)
test_joiner(model, joiner_session)
logging.info("Finished checking ONNX models")
if __name__ == "__main__":
torch.manual_seed(20220727)
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -59,21 +59,35 @@ def get_parser():
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder torchscript model. ",
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder torchscript model. ",
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner torchscript model. ",
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--joiner-encoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner encoder_proj onnx model. ",
)
parser.add_argument(
"--joiner-decoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner decoder_proj onnx model. ",
)
parser.add_argument(
@ -136,6 +150,8 @@ def read_sound_files(
def greedy_search(
decoder: ort.InferenceSession,
joiner: ort.InferenceSession,
joiner_encoder_proj: ort.InferenceSession,
joiner_decoder_proj: ort.InferenceSession,
encoder_out: np.ndarray,
encoder_out_lens: np.ndarray,
context_size: int,
@ -146,6 +162,10 @@ def greedy_search(
The decoder model.
joiner:
The joiner model.
joiner_encoder_proj:
The joiner encoder projection model.
joiner_decoder_proj:
The joiner decoder projection model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
@ -167,6 +187,15 @@ def greedy_search(
enforce_sorted=False,
)
projected_encoder_out = joiner_encoder_proj.run(
[joiner_encoder_proj.get_outputs()[0].name],
{
joiner_encoder_proj.get_inputs()[
0
].name: packed_encoder_out.data.numpy()
},
)[0]
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
@ -194,30 +223,28 @@ def greedy_search(
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
decoder_out = torch.from_numpy(decoder_out)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
current_encoder_out = projected_encoder_out[start:end]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
projected_decoder_out = projected_decoder_out[:batch_size]
logits = joiner.run(
[joiner_output_nodes[0].name],
{
joiner_input_nodes[0]
.name: current_encoder_out.unsqueeze(1)
.unsqueeze(1)
.numpy(),
joiner_input_nodes[1]
.name: decoder_out.unsqueeze(1)
.unsqueeze(1)
.numpy(),
joiner_input_nodes[0].name: current_encoder_out,
joiner_input_nodes[1].name: projected_decoder_out.numpy(),
},
)[0]
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
@ -243,7 +270,11 @@ def greedy_search(
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
decoder_out = torch.from_numpy(decoder_out)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
@ -279,6 +310,16 @@ def main():
sess_options=session_opts,
)
joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=session_opts,
)
joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=session_opts,
)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
@ -323,6 +364,8 @@ def main():
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
joiner_encoder_proj=joiner_encoder_proj,
joiner_decoder_proj=joiner_decoder_proj,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_size=args.context_size,
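Putting the pieces together: the encoder output is projected once up front, the decoder output is projected after every decoder call, and the joiner then makes one call per frame on the projected tensors. A condensed single-utterance sketch of the updated flow (assumptions beyond this diff: the `exp/` file names, random features in place of fbank, and `context_size=2`; `blank_id=0` is hard-coded as above):

import numpy as np
import onnxruntime as ort

encoder = ort.InferenceSession("exp/encoder.onnx")
decoder = ort.InferenceSession("exp/decoder.onnx")
joiner = ort.InferenceSession("exp/joiner.onnx")
enc_proj = ort.InferenceSession("exp/joiner_encoder_proj.onnx")
dec_proj = ort.InferenceSession("exp/joiner_decoder_proj.onnx")

def run1(sess, feed):
    # Run a session that has a single output and return it.
    return sess.run([sess.get_outputs()[0].name], feed)[0]

x = np.random.rand(1, 100, 80).astype(np.float32)  # (N, T, num_features)
x_lens = np.array([100], dtype=np.int64)

encoder_out, encoder_out_lens = encoder.run(
    [o.name for o in encoder.get_outputs()],
    {encoder.get_inputs()[0].name: x, encoder.get_inputs()[1].name: x_lens},
)

# Project all frames once; the joiner then works on (N, joiner_dim) tensors.
projected_encoder_out = run1(
    enc_proj, {enc_proj.get_inputs()[0].name: encoder_out[0]}
)

blank_id, context_size = 0, 2
hyp = [blank_id] * context_size

for t in range(int(encoder_out_lens[0])):
    y = np.array([hyp[-context_size:]], dtype=np.int64)  # (1, context_size)
    decoder_out = run1(decoder, {decoder.get_inputs()[0].name: y})
    projected_decoder_out = run1(
        dec_proj, {dec_proj.get_inputs()[0].name: decoder_out.squeeze(1)}
    )
    logit = run1(
        joiner,
        {
            joiner.get_inputs()[0].name: projected_encoder_out[t : t + 1],
            joiner.get_inputs()[1].name: projected_decoder_out,
        },
    )
    token = int(logit.argmax(axis=-1)[0])
    if token != blank_id:  # at most one emitted symbol per frame
        hyp.append(token)

print(hyp[context_size:])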


@ -0,0 +1,401 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file tests that models can be exported to ONNX.
"""
import os
import onnxruntime as ort
import torch
from conformer import (
Conformer,
ConformerEncoder,
ConformerEncoderLayer,
Conv2dSubsampling,
RelPositionalEncoding,
)
from scaling_converter import convert_scaled_to_non_scaled
from icefall.utils import make_pad_mask
ort.set_default_logger_severity(3)
def test_conv2d_subsampling():
filename = "conv2d_subsampling.onnx"
opset_version = 11
N = 30
T = 50
num_features = 80
d_model = 512
x = torch.rand(N, T, num_features)
encoder_embed = Conv2dSubsampling(num_features, d_model)
encoder_embed.eval()
encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True)
jit_model = torch.jit.trace(encoder_embed, x)
torch.onnx.export(
encoder_embed,
x,
filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"y": {0: "N", 1: "T"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
assert input_nodes[0].name == "x"
assert input_nodes[0].shape == ["N", "T", num_features]
inputs = {input_nodes[0].name: x.numpy()}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)
torch_y = jit_model(x)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (
(onnx_y - torch_y).abs().max()
)
os.remove(filename)
def test_rel_pos():
filename = "rel_pos.onnx"
opset_version = 11
N = 30
T = 50
num_features = 80
d_model = 512
x = torch.rand(N, T, num_features)
encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1)
encoder_pos.eval()
encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)
jit_model = torch.jit.trace(encoder_pos, x)
torch.onnx.export(
encoder_pos,
x,
filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["y", "pos_emb"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"y": {0: "N", 1: "T"},
"pos_emb": {0: "N", 1: "T"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
assert input_nodes[0].name == "x"
assert input_nodes[0].shape == ["N", "T", num_features]
inputs = {input_nodes[0].name: x.numpy()}
onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs)
onnx_y = torch.from_numpy(onnx_y)
onnx_pos_emb = torch.from_numpy(onnx_pos_emb)
torch_y, torch_pos_emb = jit_model(x)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (
(onnx_y - torch_y).abs().max()
)
assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), (
(onnx_pos_emb - torch_pos_emb).abs().max()
)
print(onnx_y.abs().sum(), torch_y.abs().sum())
print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum())
os.remove(filename)
def test_conformer_encoder_layer():
filename = "conformer_encoder_layer.onnx"
opset_version = 11
N = 30
T = 50
d_model = 512
nhead = 8
dim_feedforward = 2048
dropout = 0.1
layer_dropout = 0.075
cnn_module_kernel = 31
causal = False
x = torch.rand(N, T, d_model)
x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)
src_key_padding_mask = make_pad_mask(x_lens)
encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_pos.eval()
encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)
x, pos_emb = encoder_pos(x)
x = x.permute(1, 0, 2)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
layer_dropout,
cnn_module_kernel,
causal,
)
encoder_layer.eval()
encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True)
jit_model = torch.jit.trace(
encoder_layer, (x, pos_emb, src_key_padding_mask)
)
torch.onnx.export(
encoder_layer,
(x, pos_emb, src_key_padding_mask),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "pos_emb", "src_key_padding_mask"],
output_names=["y"],
dynamic_axes={
"x": {0: "T", 1: "N"},
"pos_emb": {0: "N", 1: "T"},
"src_key_padding_mask": {0: "N", 1: "T"},
"y": {0: "T", 1: "N"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
input_nodes[1].name: pos_emb.numpy(),
input_nodes[2].name: src_key_padding_mask.numpy(),
}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)
torch_y = jit_model(x, pos_emb, src_key_padding_mask)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (
(onnx_y - torch_y).abs().max()
)
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
os.remove(filename)
def test_conformer_encoder():
filename = "conformer_encoder.onnx"
opset_version = 11
N = 3
T = 15
d_model = 512
nhead = 8
dim_feedforward = 2048
dropout = 0.1
layer_dropout = 0.075
cnn_module_kernel = 31
causal = False
num_encoder_layers = 12
x = torch.rand(N, T, d_model)
x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)
src_key_padding_mask = make_pad_mask(x_lens)
encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_pos.eval()
encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)
x, pos_emb = encoder_pos(x)
x = x.permute(1, 0, 2)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
layer_dropout,
cnn_module_kernel,
causal,
)
encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
encoder.eval()
encoder = convert_scaled_to_non_scaled(encoder, inplace=True)
jit_model = torch.jit.trace(encoder, (x, pos_emb, src_key_padding_mask))
torch.onnx.export(
encoder,
(x, pos_emb, src_key_padding_mask),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "pos_emb", "src_key_padding_mask"],
output_names=["y"],
dynamic_axes={
"x": {0: "T", 1: "N"},
"pos_emb": {0: "N", 1: "T"},
"src_key_padding_mask": {0: "N", 1: "T"},
"y": {0: "T", 1: "N"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
input_nodes[1].name: pos_emb.numpy(),
input_nodes[2].name: src_key_padding_mask.numpy(),
}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)
torch_y = jit_model(x, pos_emb, src_key_padding_mask)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (
(onnx_y - torch_y).abs().max()
)
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
os.remove(filename)
def test_conformer():
filename = "conformer.onnx"
opset_version = 11
N = 3
T = 15
num_features = 80
x = torch.rand(N, T, num_features)
x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)
conformer = Conformer(num_features=num_features)
conformer.eval()
conformer = convert_scaled_to_non_scaled(conformer, inplace=True)
jit_model = torch.jit.trace(conformer, (x, x_lens))
torch.onnx.export(
conformer,
(x, x_lens),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["y", "y_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"y": {0: "N", 1: "T"},
"y_lens": {0: "N"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
input_nodes[1].name: x_lens.numpy(),
}
onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs)
onnx_y = torch.from_numpy(onnx_y)
onnx_y_lens = torch.from_numpy(onnx_y_lens)
torch_y, torch_y_lens = jit_model(x, x_lens)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (
(onnx_y - torch_y).abs().max()
)
assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), (
(onnx_y_lens - torch_y_lens).abs().max()
)
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
print(onnx_y_lens, torch_y_lens)
os.remove(filename)
@torch.no_grad()
def main():
test_conv2d_subsampling()
test_rel_pos()
test_conformer_encoder_layer()
test_conformer_encoder()
test_conformer()
if __name__ == "__main__":
torch.manual_seed(20221011)
main()


@ -7,5 +7,4 @@ multi_quantization
onnx
onnxruntime
--extra-index-url https://pypi.ngc.nvidia.com
onnx_graphsurgeon
dill