Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 18:42:19 +00:00)

commit f3db4ea871 (parent 0019463c83)

exporting projection layers of joiner separately for onnx (#584)

* exporting projection layers of joiner separately for onnx
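Summary: the joiner's encoder_proj and decoder_proj are now exported as standalone ONNX models (joiner_encoder_proj.onnx, joiner_decoder_proj.onnx); all_in_one.onnx bundles them together with a constants_lm metadata node carrying blank_id, unk_id, and context_size; and the check/pretrained scripts are updated for the joiner's new (N, 1, 1, C) input layout.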
@@ -58,7 +58,9 @@ log "Decode with ONNX models"
   --jit-filename $repo/exp/cpu_jit.pt \
   --onnx-encoder-filename $repo/exp/encoder.onnx \
   --onnx-decoder-filename $repo/exp/decoder.onnx \
-  --onnx-joiner-filename $repo/exp/joiner.onnx
+  --onnx-joiner-filename $repo/exp/joiner.onnx \
+  --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
+  --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
 
 ./pruned_transducer_stateless3/onnx_check_all_in_one.py \
   --jit-filename $repo/exp/cpu_jit.pt \
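The two new flags point at the separately exported projection models. As a quick sanity check, here is a minimal sketch (assuming onnxruntime is installed and the export has already been run; paths mirror the flags above) that loads them and prints their I/O names:

import onnxruntime as ort

enc_proj = ort.InferenceSession(
    "exp/joiner_encoder_proj.onnx", providers=["CPUExecutionProvider"]
)
dec_proj = ort.InferenceSession(
    "exp/joiner_decoder_proj.onnx", providers=["CPUExecutionProvider"]
)

# Each projection is a plain Linear layer exported on 2-D (N, C) inputs.
print([i.name for i in enc_proj.get_inputs()])   # expected: ['encoder_out']
print([o.name for o in enc_proj.get_outputs()])  # expected: ['encoder_proj']
print([i.name for i in dec_proj.get_inputs()])   # expected: ['decoder_out']
print([o.name for o in dec_proj.get_outputs()])  # expected: ['decoder_proj']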
@@ -62,12 +62,15 @@ It will generates 3 files: `encoder_jit_trace.pt`,
     --avg 10 \
     --onnx 1
 
-It will generate the following three files in the given `exp_dir`.
+It will generate the following six files in the given `exp_dir`.
 Check `onnx_check.py` for how to use them.
 
     - encoder.onnx
     - decoder.onnx
     - joiner.onnx
+    - joiner_encoder_proj.onnx
+    - joiner_decoder_proj.onnx
+    - all_in_one.onnx
 
 
 (4) Export `model.state_dict()`
@@ -115,6 +118,7 @@ import argparse
 import logging
 from pathlib import Path
 
+import onnx_graphsurgeon as gs
 import onnx
 import sentencepiece as spm
 import torch
@@ -218,6 +222,9 @@ def get_parser():
     - encoder.onnx
     - decoder.onnx
     - joiner.onnx
+    - joiner_encoder_proj.onnx
+    - joiner_decoder_proj.onnx
+    - all_in_one.onnx
 
     Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
     """,
@@ -485,14 +492,11 @@ def export_joiner_model_onnx(
 
       - joiner_out: a tensor of shape (N, vocab_size)
 
-    Note: The argument project_input is fixed to True. A user should not
-    project the encoder_out/decoder_out by himself/herself. The exported joiner
-    will do that for the user.
     """
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
-    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
-    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+    encoder_out = torch.rand(1, 1, 1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, 1, 1, decoder_out_dim, dtype=torch.float32)
 
     project_input = True
     # Note: It uses torch.jit.trace() internally
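The dummy inputs change from 2-D (N, C) to 4-D because the joiner in these recipes natively consumes 4-D activations; tracing with (1, 1, 1, C) bakes that layout into the exported graph. A rough sketch of why the layout is harmless, using stand-in nn.Linear layers (all dimensions and the tanh combination are assumptions for illustration, not the checkpoint's actual values):

import torch
import torch.nn as nn

# Stand-in dims; real values come from the checkpoint.
N, C, joiner_dim = 2, 512, 512

encoder_proj = nn.Linear(C, joiner_dim)
decoder_proj = nn.Linear(C, joiner_dim)

encoder_out = torch.rand(N, 1, 1, C)  # same layout as the ONNX dummies
decoder_out = torch.rand(N, 1, 1, C)

# nn.Linear acts on the last dimension, so the projections broadcast over
# the two middle axes; with project_input=True the joiner applies them
# (and the subsequent combination) internally.
joined = torch.tanh(encoder_proj(encoder_out) + decoder_proj(decoder_out))
print(joined.shape)  # torch.Size([2, 1, 1, 512])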
@@ -510,10 +514,63 @@ def export_joiner_model_onnx(
             "logit": {0: "N"},
         },
     )
+    torch.onnx.export(
+        joiner_model.encoder_proj,
+        (encoder_out.squeeze(0).squeeze(0)),
+        str(joiner_filename).replace(".onnx", "_encoder_proj.onnx"),
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["encoder_out"],
+        output_names=["encoder_proj"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "encoder_proj": {0: "N"},
+        },
+    )
+    torch.onnx.export(
+        joiner_model.decoder_proj,
+        (decoder_out.squeeze(0).squeeze(0)),
+        str(joiner_filename).replace(".onnx", "_decoder_proj.onnx"),
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["decoder_out"],
+        output_names=["decoder_proj"],
+        dynamic_axes={
+            "decoder_out": {0: "N"},
+            "decoder_proj": {0: "N"},
+        },
+    )
     logging.info(f"Saved to {joiner_filename}")
 
 
+def add_variables(
+    model: nn.Module, combined_model: onnx.ModelProto
+) -> onnx.ModelProto:
+    graph = gs.import_onnx(combined_model)
+
+    blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+
+    node = gs.Node(
+        op="Identity",
+        name="constants_lm",
+        attrs={
+            "blank_id": blank_id,
+            "unk_id": unk_id,
+            "context_size": context_size,
+        },
+        inputs=[],
+        outputs=[],
+    )
+    graph.nodes.append(node)
+
+    graph = gs.export_onnx(graph)
+    return graph
+
+
 def export_all_in_one_onnx(
+    model: nn.Module,
     encoder_filename: str,
     decoder_filename: str,
     joiner_filename: str,
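Note how the projection exports reuse the 4-D dummies but squeeze them back to (1, C): each projection is a plain Linear layer, so a 2-D trace with a dynamic batch axis is all it needs. A minimal sketch of the same pattern on a stand-in layer (the filename, dims, and opset here are assumptions):

import torch
import torch.nn as nn

proj = nn.Linear(512, 512)
dummy = torch.rand(1, 512)  # the squeezed (1, 1, 1, C) dummy becomes (1, C)

torch.onnx.export(
    proj,
    (dummy,),
    "encoder_proj_sketch.onnx",  # hypothetical output path
    opset_version=11,
    input_names=["encoder_out"],
    output_names=["encoder_proj"],
    dynamic_axes={"encoder_out": {0: "N"}, "encoder_proj": {0: "N"}},
)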
@@ -522,10 +579,22 @@ def export_all_in_one_onnx(
     encoder_onnx = onnx.load(encoder_filename)
     decoder_onnx = onnx.load(decoder_filename)
     joiner_onnx = onnx.load(joiner_filename)
+    joiner_encoder_proj_onnx = onnx.load(
+        str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
+    )
+    joiner_decoder_proj_onnx = onnx.load(
+        str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
+    )
 
     encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
     decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
     joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
+    joiner_encoder_proj_onnx = onnx.compose.add_prefix(
+        joiner_encoder_proj_onnx, prefix="joiner_encoder_proj/"
+    )
+    joiner_decoder_proj_onnx = onnx.compose.add_prefix(
+        joiner_decoder_proj_onnx, prefix="joiner_decoder_proj/"
+    )
 
     combined_model = onnx.compose.merge_models(
         encoder_onnx, decoder_onnx, io_map={}
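For readers unfamiliar with onnx.compose: merging with io_map={} wires nothing together; it simply places the two graphs side by side in one ModelProto, which is why each part is prefixed first to avoid name collisions. A small sketch of the pattern (hypothetical file names):

import onnx
from onnx import compose

# Prefixing renames every tensor and node so otherwise-identical graphs
# can coexist; an empty io_map means no outputs are fed into inputs.
a = compose.add_prefix(onnx.load("encoder.onnx"), prefix="encoder/")
b = compose.add_prefix(onnx.load("decoder.onnx"), prefix="decoder/")
merged = compose.merge_models(a, b, io_map={})
print(len(merged.graph.node))  # nodes from both subgraphs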
@@ -533,6 +602,13 @@ def export_all_in_one_onnx(
     combined_model = onnx.compose.merge_models(
         combined_model, joiner_onnx, io_map={}
     )
+    combined_model = onnx.compose.merge_models(
+        combined_model, joiner_encoder_proj_onnx, io_map={}
+    )
+    combined_model = onnx.compose.merge_models(
+        combined_model, joiner_decoder_proj_onnx, io_map={}
+    )
+    combined_model = add_variables(model, combined_model)
     onnx.save(combined_model, all_in_one_filename)
     logging.info(f"Saved to {all_in_one_filename}")
 
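add_variables smuggles the decoder constants into the combined graph as attributes on a no-op Identity node named constants_lm, so a consumer of all_in_one.onnx can recover them without the Python checkpoint. A hedged sketch of reading them back (the path and example values are assumptions):

import onnx
from onnx import helper

m = onnx.load("exp/all_in_one.onnx")
for node in m.graph.node:
    if node.name == "constants_lm":
        # Attribute values round-trip as plain ints.
        consts = {a.name: helper.get_attribute_value(a) for a in node.attribute}
        print(consts)  # e.g. {'blank_id': 0, 'unk_id': 2, 'context_size': 2}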
@@ -631,6 +707,7 @@ def main():
 
     all_in_one_filename = params.exp_dir / "all_in_one.onnx"
     export_all_in_one_onnx(
+        model,
         encoder_filename,
         decoder_filename,
         joiner_filename,
@@ -63,6 +63,20 @@ def get_parser():
         help="Path to the onnx joiner model",
     )
 
+    parser.add_argument(
+        "--onnx-joiner-encoder-proj-filename",
+        required=True,
+        type=str,
+        help="Path to the onnx joiner encoder projection model",
+    )
+
+    parser.add_argument(
+        "--onnx-joiner-decoder-proj-filename",
+        required=True,
+        type=str,
+        help="Path to the onnx joiner decoder projection model",
+    )
+
     return parser
 
 
@@ -126,17 +140,27 @@ def test_decoder(
 def test_joiner(
     model: torch.jit.ScriptModule,
     joiner_session: ort.InferenceSession,
+    joiner_encoder_proj_session: ort.InferenceSession,
+    joiner_decoder_proj_session: ort.InferenceSession,
 ):
     joiner_inputs = joiner_session.get_inputs()
     assert joiner_inputs[0].name == "encoder_out"
-    assert joiner_inputs[0].shape == ["N", 512]
+    assert joiner_inputs[0].shape == ["N", 1, 1, 512]
 
     assert joiner_inputs[1].name == "decoder_out"
-    assert joiner_inputs[1].shape == ["N", 512]
+    assert joiner_inputs[1].shape == ["N", 1, 1, 512]
 
+    joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
+    assert joiner_encoder_proj_inputs[0].name == "encoder_out"
+    assert joiner_encoder_proj_inputs[0].shape == ["N", 512]
+
+    joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
+    assert joiner_decoder_proj_inputs[0].name == "decoder_out"
+    assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
+
     for N in [1, 5, 10]:
-        encoder_out = torch.rand(N, 512)
-        decoder_out = torch.rand(N, 512)
+        encoder_out = torch.rand(N, 1, 1, 512)
+        decoder_out = torch.rand(N, 1, 1, 512)
 
         joiner_inputs = {
             "encoder_out": encoder_out.numpy(),
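The shape assertions mix the string "N" with fixed ints because onnxruntime reports axes declared under dynamic_axes by their symbolic names and static axes as integers. A small illustration (file path assumed):

import onnxruntime as ort

sess = ort.InferenceSession(
    "exp/joiner.onnx", providers=["CPUExecutionProvider"]
)
inp = sess.get_inputs()[0]
# Dims exported as dynamic come back as their string names ("N");
# fixed dims come back as ints, hence shapes like ["N", 1, 1, 512].
print(inp.name, inp.shape)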
@@ -154,6 +178,44 @@ def test_joiner(
             (joiner_out - torch_joiner_out).abs().max()
         )
 
+        joiner_encoder_proj_inputs = {
+            "encoder_out": encoder_out.squeeze(1).squeeze(1).numpy()
+        }
+        joiner_encoder_proj_out = joiner_encoder_proj_session.run(
+            ["encoder_proj"], joiner_encoder_proj_inputs
+        )[0]
+        joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
+
+        torch_joiner_encoder_proj_out = model.joiner.encoder_proj(
+            encoder_out.squeeze(1).squeeze(1)
+        )
+        assert torch.allclose(
+            joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
+        ), (
+            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
+            .abs()
+            .max()
+        )
+
+        joiner_decoder_proj_inputs = {
+            "decoder_out": decoder_out.squeeze(1).squeeze(1).numpy()
+        }
+        joiner_decoder_proj_out = joiner_decoder_proj_session.run(
+            ["decoder_proj"], joiner_decoder_proj_inputs
+        )[0]
+        joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
+
+        torch_joiner_decoder_proj_out = model.joiner.decoder_proj(
+            decoder_out.squeeze(1).squeeze(1)
+        )
+        assert torch.allclose(
+            joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
+        ), (
+            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
+            .abs()
+            .max()
+        )
+
 
 @torch.no_grad()
 def main():
@@ -185,7 +247,20 @@ def main():
         args.onnx_joiner_filename,
         sess_options=options,
     )
-    test_joiner(model, joiner_session)
+    joiner_encoder_proj_session = ort.InferenceSession(
+        args.onnx_joiner_encoder_proj_filename,
+        sess_options=options,
+    )
+    joiner_decoder_proj_session = ort.InferenceSession(
+        args.onnx_joiner_decoder_proj_filename,
+        sess_options=options,
+    )
+    test_joiner(
+        model,
+        joiner_session,
+        joiner_encoder_proj_session,
+        joiner_decoder_proj_session,
+    )
     logging.info("Finished checking ONNX models")
 
 
@@ -27,7 +27,7 @@ You can use the following command to get the exported models:
 
 Usage of this script:
 
-./pruned_transducer_stateless3/jit_trace_pretrained.py \
+./pruned_transducer_stateless3/onnx_pretrained.py \
   --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
   --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
   --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
@@ -194,6 +194,7 @@ def greedy_search(
             decoder_input_nodes[0].name: decoder_input.numpy(),
         },
     )[0].squeeze(1)
+    decoder_out = torch.from_numpy(decoder_out)
 
     offset = 0
     for batch_size in batch_size_list:
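The added torch.from_numpy conversion matters because InferenceSession.run returns numpy arrays, while the next hunk calls unsqueeze on decoder_out, which exists only on torch tensors. A tiny stand-in demonstration:

import numpy as np
import torch

decoder_out = np.zeros((8, 512), dtype=np.float32)  # stand-in for session output
decoder_out = torch.from_numpy(decoder_out)
print(decoder_out.unsqueeze(1).unsqueeze(1).shape)  # torch.Size([8, 1, 1, 512])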
@@ -209,11 +210,17 @@ def greedy_search(
         logits = joiner.run(
             [joiner_output_nodes[0].name],
             {
-                joiner_input_nodes[0].name: current_encoder_out.numpy(),
-                joiner_input_nodes[1].name: decoder_out,
+                joiner_input_nodes[0]
+                .name: current_encoder_out.unsqueeze(1)
+                .unsqueeze(1)
+                .numpy(),
+                joiner_input_nodes[1]
+                .name: decoder_out.unsqueeze(1)
+                .unsqueeze(1)
+                .numpy(),
             },
         )[0]
-        logits = torch.from_numpy(logits)
+        logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
         # logits'shape (batch_size, vocab_size)
 
         assert logits.ndim == 2, logits.shape
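The unsqueeze/squeeze pairs are the shape adapter between greedy search, which works with 2-D (N, C) activations, and the re-exported joiner, which now expects (N, 1, 1, C) and emits (N, 1, 1, vocab_size). A stand-in round trip (all sizes assumed):

import torch

current_encoder_out = torch.rand(8, 512)           # (N, C)
x = current_encoder_out.unsqueeze(1).unsqueeze(1)  # (N, 1, 1, C) into the joiner
logits4d = torch.rand(8, 1, 1, 500)                # stand-in joiner output
logits = logits4d.squeeze(1).squeeze(1)            # back to (N, vocab_size)
assert logits.ndim == 2, logits.shape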
@@ -236,6 +243,7 @@ def greedy_search(
             decoder_input_nodes[0].name: decoder_input.numpy(),
         },
     )[0].squeeze(1)
+    decoder_out = torch.from_numpy(decoder_out)
 
     sorted_ans = [h[context_size:] for h in hyps]
     ans = []