exporting projection layers of joiner separately for onnx (#584)

* exporting projection layers of joiner separately for onnx
Yunusemre 2022-10-11 13:22:28 +03:00 committed by GitHub
parent 0019463c83
commit f3db4ea871
4 changed files with 178 additions and 16 deletions


@@ -58,7 +58,9 @@ log "Decode with ONNX models"
   --jit-filename $repo/exp/cpu_jit.pt \
   --onnx-encoder-filename $repo/exp/encoder.onnx \
   --onnx-decoder-filename $repo/exp/decoder.onnx \
-  --onnx-joiner-filename $repo/exp/joiner.onnx
+  --onnx-joiner-filename $repo/exp/joiner.onnx \
+  --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
+  --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx

 ./pruned_transducer_stateless3/onnx_check_all_in_one.py \
   --jit-filename $repo/exp/cpu_jit.pt \


@@ -62,12 +62,15 @@ It will generates 3 files: `encoder_jit_trace.pt`,
     --avg 10 \
     --onnx 1

-It will generate the following three files in the given `exp_dir`.
+It will generate the following six files in the given `exp_dir`.
 Check `onnx_check.py` for how to use them.

   - encoder.onnx
   - decoder.onnx
   - joiner.onnx
+  - joiner_encoder_proj.onnx
+  - joiner_decoder_proj.onnx
   - all_in_one.onnx

 (4) Export `model.state_dict()`
@@ -115,6 +118,7 @@ import argparse
 import logging
 from pathlib import Path

+import onnx_graphsurgeon as gs
 import onnx
 import sentencepiece as spm
 import torch
@@ -218,6 +222,9 @@ def get_parser():
             - encoder.onnx
             - decoder.onnx
             - joiner.onnx
+            - joiner_encoder_proj.onnx
+            - joiner_decoder_proj.onnx
             - all_in_one.onnx

         Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
         """,
@@ -485,14 +492,11 @@ def export_joiner_model_onnx(
       - joiner_out: a tensor of shape (N, vocab_size)

     Note: The argument project_input is fixed to True. A user should not
     project the encoder_out/decoder_out by himself/herself. The exported joiner
     will do that for the user.
     """
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
-    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
-    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+    encoder_out = torch.rand(1, 1, 1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, 1, 1, decoder_out_dim, dtype=torch.float32)

     project_input = True
     # Note: It uses torch.jit.trace() internally
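The note above says the exported joiner folds both projections into its own graph. A minimal sketch of the equivalence that implies, assuming the recipe's `Joiner.forward(encoder_out, decoder_out, project_input)` signature (all names here are placeholders, not part of the commit):

import torch

def check_project_input_equivalence(joiner, encoder_out, decoder_out):
    # Sketch only: the exported ONNX joiner traces the project_input=True
    # path, so callers must NOT pre-project encoder_out/decoder_out.
    projected = joiner(encoder_out, decoder_out, project_input=True)
    manual = joiner(
        joiner.encoder_proj(encoder_out),
        joiner.decoder_proj(decoder_out),
        project_input=False,
    )
    assert torch.allclose(projected, manual, atol=1e-5)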
@@ -510,10 +514,63 @@ def export_joiner_model_onnx(
             "logit": {0: "N"},
         },
     )
+    torch.onnx.export(
+        joiner_model.encoder_proj,
+        (encoder_out.squeeze(0).squeeze(0)),
+        str(joiner_filename).replace(".onnx", "_encoder_proj.onnx"),
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["encoder_out"],
+        output_names=["encoder_proj"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "encoder_proj": {0: "N"},
+        },
+    )
+    torch.onnx.export(
+        joiner_model.decoder_proj,
+        (decoder_out.squeeze(0).squeeze(0)),
+        str(joiner_filename).replace(".onnx", "_decoder_proj.onnx"),
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["decoder_out"],
+        output_names=["decoder_proj"],
+        dynamic_axes={
+            "decoder_out": {0: "N"},
+            "decoder_proj": {0: "N"},
+        },
+    )
     logging.info(f"Saved to {joiner_filename}")
+def add_variables(
+    model: nn.Module, combined_model: onnx.ModelProto
+) -> onnx.ModelProto:
+    graph = gs.import_onnx(combined_model)
+
+    blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+
+    node = gs.Node(
+        op="Identity",
+        name="constants_lm",
+        attrs={
+            "blank_id": blank_id,
+            "unk_id": unk_id,
+            "context_size": context_size,
+        },
+        inputs=[],
+        outputs=[],
+    )
+    graph.nodes.append(node)
+
+    graph = gs.export_onnx(graph)
+    return graph
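The Identity node carries no tensors; it only smuggles the decoder constants into the combined file as node attributes. Reading them back might look like this (a sketch, assuming the node name and attribute names used above):

import onnx
from onnx import helper

def read_constants_lm(filename: str) -> dict:
    # Scan all_in_one.onnx for the attribute-only "constants_lm" node.
    model = onnx.load(filename)
    for node in model.graph.node:
        if node.name == "constants_lm":
            return {
                a.name: helper.get_attribute_value(a) for a in node.attribute
            }
    raise ValueError("constants_lm node not found")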
 def export_all_in_one_onnx(
+    model: nn.Module,
     encoder_filename: str,
     decoder_filename: str,
     joiner_filename: str,
@@ -522,10 +579,22 @@ def export_all_in_one_onnx(
     encoder_onnx = onnx.load(encoder_filename)
     decoder_onnx = onnx.load(decoder_filename)
     joiner_onnx = onnx.load(joiner_filename)
+    joiner_encoder_proj_onnx = onnx.load(
+        str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
+    )
+    joiner_decoder_proj_onnx = onnx.load(
+        str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
+    )

     encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
     decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
     joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
+    joiner_encoder_proj_onnx = onnx.compose.add_prefix(
+        joiner_encoder_proj_onnx, prefix="joiner_encoder_proj/"
+    )
+    joiner_decoder_proj_onnx = onnx.compose.add_prefix(
+        joiner_decoder_proj_onnx, prefix="joiner_decoder_proj/"
+    )

     combined_model = onnx.compose.merge_models(
         encoder_onnx, decoder_onnx, io_map={}
@@ -533,6 +602,13 @@ def export_all_in_one_onnx(
     combined_model = onnx.compose.merge_models(
         combined_model, joiner_onnx, io_map={}
     )
+    combined_model = onnx.compose.merge_models(
+        combined_model, joiner_encoder_proj_onnx, io_map={}
+    )
+    combined_model = onnx.compose.merge_models(
+        combined_model, joiner_decoder_proj_onnx, io_map={}
+    )
+    combined_model = add_variables(model, combined_model)

     onnx.save(combined_model, all_in_one_filename)
     logging.info(f"Saved to {all_in_one_filename}")
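With an empty io_map, merge_models wires nothing together; combined with add_prefix, the five graphs simply coexist in one file under distinct namespaces. A toy illustration of that behavior (hypothetical one-node models, not the recipe's graphs):

import onnx
from onnx import TensorProto, helper

def tiny_model(name: str) -> onnx.ModelProto:
    # A one-node Identity graph with a symbolic batch dimension.
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["N"])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, ["N"])
    node = helper.make_node("Identity", ["x"], ["y"], name=name)
    return helper.make_model(helper.make_graph([node], name, [x], [y]))

a = onnx.compose.add_prefix(tiny_model("a"), prefix="a/")
b = onnx.compose.add_prefix(tiny_model("b"), prefix="b/")
merged = onnx.compose.merge_models(a, b, io_map=[])
print([i.name for i in merged.graph.input])   # ['a/x', 'b/x']
print([o.name for o in merged.graph.output])  # ['a/y', 'b/y']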
@@ -631,6 +707,7 @@ def main():
         all_in_one_filename = params.exp_dir / "all_in_one.onnx"
         export_all_in_one_onnx(
+            model,
             encoder_filename,
             decoder_filename,
             joiner_filename,


@@ -63,6 +63,20 @@ def get_parser():
         help="Path to the onnx joiner model",
     )

+    parser.add_argument(
+        "--onnx-joiner-encoder-proj-filename",
+        required=True,
+        type=str,
+        help="Path to the onnx joiner encoder projection model",
+    )
+
+    parser.add_argument(
+        "--onnx-joiner-decoder-proj-filename",
+        required=True,
+        type=str,
+        help="Path to the onnx joiner decoder projection model",
+    )
+
     return parser
@@ -126,17 +140,27 @@ def test_decoder(
 def test_joiner(
     model: torch.jit.ScriptModule,
     joiner_session: ort.InferenceSession,
+    joiner_encoder_proj_session: ort.InferenceSession,
+    joiner_decoder_proj_session: ort.InferenceSession,
 ):
     joiner_inputs = joiner_session.get_inputs()
     assert joiner_inputs[0].name == "encoder_out"
-    assert joiner_inputs[0].shape == ["N", 512]
+    assert joiner_inputs[0].shape == ["N", 1, 1, 512]

     assert joiner_inputs[1].name == "decoder_out"
-    assert joiner_inputs[1].shape == ["N", 512]
+    assert joiner_inputs[1].shape == ["N", 1, 1, 512]

+    joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
+    assert joiner_encoder_proj_inputs[0].name == "encoder_out"
+    assert joiner_encoder_proj_inputs[0].shape == ["N", 512]
+
+    joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
+    assert joiner_decoder_proj_inputs[0].name == "decoder_out"
+    assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
+
     for N in [1, 5, 10]:
-        encoder_out = torch.rand(N, 512)
-        decoder_out = torch.rand(N, 512)
+        encoder_out = torch.rand(N, 1, 1, 512)
+        decoder_out = torch.rand(N, 1, 1, 512)

         joiner_inputs = {
             "encoder_out": encoder_out.numpy(),
@@ -154,6 +178,44 @@ def test_joiner(
             (joiner_out - torch_joiner_out).abs().max()
         )

+        joiner_encoder_proj_inputs = {
+            "encoder_out": encoder_out.squeeze(1).squeeze(1).numpy()
+        }
+        joiner_encoder_proj_out = joiner_encoder_proj_session.run(
+            ["encoder_proj"], joiner_encoder_proj_inputs
+        )[0]
+        joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
+
+        torch_joiner_encoder_proj_out = model.joiner.encoder_proj(
+            encoder_out.squeeze(1).squeeze(1)
+        )
+        assert torch.allclose(
+            joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
+        ), (
+            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
+            .abs()
+            .max()
+        )
+
+        joiner_decoder_proj_inputs = {
+            "decoder_out": decoder_out.squeeze(1).squeeze(1).numpy()
+        }
+        joiner_decoder_proj_out = joiner_decoder_proj_session.run(
+            ["decoder_proj"], joiner_decoder_proj_inputs
+        )[0]
+        joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
+
+        torch_joiner_decoder_proj_out = model.joiner.decoder_proj(
+            decoder_out.squeeze(1).squeeze(1)
+        )
+        assert torch.allclose(
+            joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
+        ), (
+            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
+            .abs()
+            .max()
+        )

 @torch.no_grad()
 def main():
@@ -185,7 +247,20 @@ def main():
         args.onnx_joiner_filename,
         sess_options=options,
     )
-    test_joiner(model, joiner_session)
+    joiner_encoder_proj_session = ort.InferenceSession(
+        args.onnx_joiner_encoder_proj_filename,
+        sess_options=options,
+    )
+    joiner_decoder_proj_session = ort.InferenceSession(
+        args.onnx_joiner_decoder_proj_filename,
+        sess_options=options,
+    )
+    test_joiner(
+        model,
+        joiner_session,
+        joiner_encoder_proj_session,
+        joiner_decoder_proj_session,
+    )

     logging.info("Finished checking ONNX models")


@@ -27,7 +27,7 @@ You can use the following command to get the exported models:

 Usage of this script:

-./pruned_transducer_stateless3/jit_trace_pretrained.py \
+./pruned_transducer_stateless3/onnx_pretrained.py \
   --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
   --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
   --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
@@ -194,6 +194,7 @@ def greedy_search(
             decoder_input_nodes[0].name: decoder_input.numpy(),
         },
     )[0].squeeze(1)
+    decoder_out = torch.from_numpy(decoder_out)

     offset = 0
     for batch_size in batch_size_list:
@@ -209,11 +210,17 @@ def greedy_search(
         logits = joiner.run(
             [joiner_output_nodes[0].name],
             {
-                joiner_input_nodes[0].name: current_encoder_out.numpy(),
-                joiner_input_nodes[1].name: decoder_out,
+                joiner_input_nodes[0]
+                .name: current_encoder_out.unsqueeze(1)
+                .unsqueeze(1)
+                .numpy(),
+                joiner_input_nodes[1]
+                .name: decoder_out.unsqueeze(1)
+                .unsqueeze(1)
+                .numpy(),
             },
         )[0]
-        logits = torch.from_numpy(logits)
+        logits = torch.from_numpy(logits).squeeze(1).squeeze(1)

         # logits'shape (batch_size, vocab_size)
         assert logits.ndim == 2, logits.shape
@@ -236,6 +243,7 @@ def greedy_search(
                     decoder_input_nodes[0].name: decoder_input.numpy(),
                 },
             )[0].squeeze(1)
+            decoder_out = torch.from_numpy(decoder_out)

     sorted_ans = [h[context_size:] for h in hyps]
     ans = []
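For reference, the shape bookkeeping these greedy-search changes implement, as a standalone sketch (illustrative sizes only; the exported joiner consumes (N, 1, 1, C) tensors and its logits are squeezed back to (N, vocab_size)):

import torch

N, C, V = 3, 512, 500  # batch size, joiner input dim, vocab size (illustrative)
current_encoder_out = torch.rand(N, C)
decoder_out = torch.rand(N, C)

enc = current_encoder_out.unsqueeze(1).unsqueeze(1)  # (N, 1, 1, C)
dec = decoder_out.unsqueeze(1).unsqueeze(1)          # (N, 1, 1, C)

# A joiner.run(...) call on the exported model would return (N, 1, 1, V);
# torch.rand stands in for that output here.
logits = torch.rand(N, 1, 1, V)
logits = logits.squeeze(1).squeeze(1)  # back to (N, V) for the argmax step
assert logits.shape == (N, V)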