exporting projection layers of joiner separately for onnx (#584)

* exporting projection layers of joiner separately for onnx

This commit is contained in:
parent 0019463c83
commit f3db4ea871
@@ -58,7 +58,9 @@ log "Decode with ONNX models"
   --jit-filename $repo/exp/cpu_jit.pt \
   --onnx-encoder-filename $repo/exp/encoder.onnx \
   --onnx-decoder-filename $repo/exp/decoder.onnx \
-  --onnx-joiner-filename $repo/exp/joiner.onnx
+  --onnx-joiner-filename $repo/exp/joiner.onnx \
+  --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
+  --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx

 ./pruned_transducer_stateless3/onnx_check_all_in_one.py \
   --jit-filename $repo/exp/cpu_jit.pt \
@@ -62,12 +62,15 @@ It will generates 3 files: `encoder_jit_trace.pt`,
   --avg 10 \
   --onnx 1

-It will generate the following three files in the given `exp_dir`.
+It will generate the following six files in the given `exp_dir`.
 Check `onnx_check.py` for how to use them.

   - encoder.onnx
   - decoder.onnx
   - joiner.onnx
+  - joiner_encoder_proj.onnx
+  - joiner_decoder_proj.onnx
   - all_in_one.onnx


 (4) Export `model.state_dict()`
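For context (not part of this diff): a minimal onnxruntime sketch of how the two new projection models can be consumed. The `exp/` paths and the 512-dim inputs are assumptions; the tensor names follow the torch.onnx.export calls later in this commit.

# Hedged sketch, not from this commit: load and run the two new
# projection models with onnxruntime. Paths and dims are assumptions;
# the names ("encoder_out" -> "encoder_proj", etc.) match the exports below.
import numpy as np
import onnxruntime as ort

enc_proj_sess = ort.InferenceSession("exp/joiner_encoder_proj.onnx")
dec_proj_sess = ort.InferenceSession("exp/joiner_decoder_proj.onnx")

encoder_out = np.random.rand(4, 512).astype(np.float32)  # (N, encoder_out_dim)
decoder_out = np.random.rand(4, 512).astype(np.float32)  # (N, decoder_out_dim)

enc_projected = enc_proj_sess.run(["encoder_proj"], {"encoder_out": encoder_out})[0]
dec_projected = dec_proj_sess.run(["decoder_proj"], {"decoder_out": decoder_out})[0]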
@@ -115,6 +118,7 @@ import argparse
 import logging
 from pathlib import Path

+import onnx_graphsurgeon as gs
 import onnx
 import sentencepiece as spm
 import torch
@@ -218,6 +222,9 @@ def get_parser():
           - encoder.onnx
           - decoder.onnx
           - joiner.onnx
+          - joiner_encoder_proj.onnx
+          - joiner_decoder_proj.onnx
           - all_in_one.onnx

         Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
         """,
@@ -485,14 +492,11 @@ def export_joiner_model_onnx(

       - joiner_out: a tensor of shape (N, vocab_size)

-    Note: The argument project_input is fixed to True. A user should not
-    project the encoder_out/decoder_out by himself/herself. The exported joiner
-    will do that for the user.
     """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
-    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
-    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+    encoder_out = torch.rand(1, 1, 1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, 1, 1, decoder_out_dim, dtype=torch.float32)

     project_input = True
     # Note: It uses torch.jit.trace() internally
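For context (not part of the diff): the equivalence that makes exporting the projections separately safe. Below is a hedged sketch with a minimal stand-in joiner; the real module lives in pruned_transducer_stateless3/joiner.py and may differ in detail, but its forward() takes a project_input flag as described above.

# Hedged sketch with a minimal stand-in Joiner (an assumption for
# illustration, not the icefall implementation).
import torch
import torch.nn as nn


class Joiner(nn.Module):
    def __init__(self, encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out, decoder_out, project_input=True):
        if project_input:
            x = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            x = encoder_out + decoder_out
        return self.output_linear(torch.tanh(x))


joiner = Joiner()
encoder_out = torch.rand(1, 1, 1, 512)  # (N, T, s_range, C), as traced above
decoder_out = torch.rand(1, 1, 1, 512)

# joiner.onnx is traced with project_input=True: projections happen inside.
logit_a = joiner(encoder_out, decoder_out, project_input=True)

# With the projections exported separately, a caller can project once up
# front and then join with project_input=False.
logit_b = joiner(
    joiner.encoder_proj(encoder_out),
    joiner.decoder_proj(decoder_out),
    project_input=False,
)
assert torch.allclose(logit_a, logit_b, atol=1e-5)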
@@ -510,10 +514,63 @@ def export_joiner_model_onnx(
             "logit": {0: "N"},
         },
     )
+    torch.onnx.export(
+        joiner_model.encoder_proj,
+        (encoder_out.squeeze(0).squeeze(0)),
+        str(joiner_filename).replace(".onnx", "_encoder_proj.onnx"),
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["encoder_out"],
+        output_names=["encoder_proj"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "encoder_proj": {0: "N"},
+        },
+    )
+    torch.onnx.export(
+        joiner_model.decoder_proj,
+        (decoder_out.squeeze(0).squeeze(0)),
+        str(joiner_filename).replace(".onnx", "_decoder_proj.onnx"),
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["decoder_out"],
+        output_names=["decoder_proj"],
+        dynamic_axes={
+            "decoder_out": {0: "N"},
+            "decoder_proj": {0: "N"},
+        },
+    )
     logging.info(f"Saved to {joiner_filename}")


+def add_variables(
+    model: nn.Module, combined_model: onnx.ModelProto
+) -> onnx.ModelProto:
+    graph = gs.import_onnx(combined_model)
+
+    blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+
+    node = gs.Node(
+        op="Identity",
+        name="constants_lm",
+        attrs={
+            "blank_id": blank_id,
+            "unk_id": unk_id,
+            "context_size": context_size,
+        },
+        inputs=[],
+        outputs=[],
+    )
+    graph.nodes.append(node)
+
+    graph = gs.export_onnx(graph)
+    return graph
+
+
 def export_all_in_one_onnx(
     model: nn.Module,
     encoder_filename: str,
     decoder_filename: str,
     joiner_filename: str,
@@ -522,10 +579,22 @@ def export_all_in_one_onnx(
     encoder_onnx = onnx.load(encoder_filename)
     decoder_onnx = onnx.load(decoder_filename)
     joiner_onnx = onnx.load(joiner_filename)
+    joiner_encoder_proj_onnx = onnx.load(
+        str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
+    )
+    joiner_decoder_proj_onnx = onnx.load(
+        str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
+    )

     encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
     decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
     joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
+    joiner_encoder_proj_onnx = onnx.compose.add_prefix(
+        joiner_encoder_proj_onnx, prefix="joiner_encoder_proj/"
+    )
+    joiner_decoder_proj_onnx = onnx.compose.add_prefix(
+        joiner_decoder_proj_onnx, prefix="joiner_decoder_proj/"
+    )

     combined_model = onnx.compose.merge_models(
         encoder_onnx, decoder_onnx, io_map={}
@@ -533,6 +602,13 @@ def export_all_in_one_onnx(
     combined_model = onnx.compose.merge_models(
         combined_model, joiner_onnx, io_map={}
     )
+    combined_model = onnx.compose.merge_models(
+        combined_model, joiner_encoder_proj_onnx, io_map={}
+    )
+    combined_model = onnx.compose.merge_models(
+        combined_model, joiner_decoder_proj_onnx, io_map={}
+    )
+    combined_model = add_variables(model, combined_model)
     onnx.save(combined_model, all_in_one_filename)
     logging.info(f"Saved to {all_in_one_filename}")

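A consumer-side sketch (an assumption, not in this commit) of how the metadata attached by add_variables() can be read back from all_in_one.onnx:

# Hedged sketch: recover blank_id/unk_id/context_size from the
# "constants_lm" Identity node added by add_variables() above.
import onnx
from onnx import helper

m = onnx.load("exp/all_in_one.onnx")  # path is an assumption
for node in m.graph.node:
    if node.name == "constants_lm":
        attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}
        print(attrs["blank_id"], attrs["unk_id"], attrs["context_size"])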
@@ -631,6 +707,7 @@ def main():

     all_in_one_filename = params.exp_dir / "all_in_one.onnx"
     export_all_in_one_onnx(
+        model,
         encoder_filename,
         decoder_filename,
         joiner_filename,
@@ -63,6 +63,20 @@ def get_parser():
         help="Path to the onnx joiner model",
     )

+    parser.add_argument(
+        "--onnx-joiner-encoder-proj-filename",
+        required=True,
+        type=str,
+        help="Path to the onnx joiner encoder projection model",
+    )
+
+    parser.add_argument(
+        "--onnx-joiner-decoder-proj-filename",
+        required=True,
+        type=str,
+        help="Path to the onnx joiner decoder projection model",
+    )
+
     return parser

@@ -126,17 +140,27 @@ def test_decoder(
 def test_joiner(
     model: torch.jit.ScriptModule,
     joiner_session: ort.InferenceSession,
+    joiner_encoder_proj_session: ort.InferenceSession,
+    joiner_decoder_proj_session: ort.InferenceSession,
 ):
     joiner_inputs = joiner_session.get_inputs()
     assert joiner_inputs[0].name == "encoder_out"
-    assert joiner_inputs[0].shape == ["N", 512]
+    assert joiner_inputs[0].shape == ["N", 1, 1, 512]

     assert joiner_inputs[1].name == "decoder_out"
-    assert joiner_inputs[1].shape == ["N", 512]
+    assert joiner_inputs[1].shape == ["N", 1, 1, 512]

+    joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
+    assert joiner_encoder_proj_inputs[0].name == "encoder_out"
+    assert joiner_encoder_proj_inputs[0].shape == ["N", 512]
+
+    joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
+    assert joiner_decoder_proj_inputs[0].name == "decoder_out"
+    assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
+
     for N in [1, 5, 10]:
-        encoder_out = torch.rand(N, 512)
-        decoder_out = torch.rand(N, 512)
+        encoder_out = torch.rand(N, 1, 1, 512)
+        decoder_out = torch.rand(N, 1, 1, 512)

         joiner_inputs = {
             "encoder_out": encoder_out.numpy(),
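The assertions above encode the new shape contract: the combined joiner takes 4-D (N, 1, 1, 512) inputs, while the two projection models keep 2-D (N, 512) inputs. A quick introspection sketch (hedged; file paths are assumptions):

# Hedged sketch: print the declared input shapes of the exported models.
import onnxruntime as ort

for name in ("joiner", "joiner_encoder_proj", "joiner_decoder_proj"):
    sess = ort.InferenceSession(f"exp/{name}.onnx")  # paths are assumptions
    for inp in sess.get_inputs():
        print(name, inp.name, inp.shape)
# Expected, per the assertions above:
#   joiner               encoder_out ['N', 1, 1, 512]
#   joiner               decoder_out ['N', 1, 1, 512]
#   joiner_encoder_proj  encoder_out ['N', 512]
#   joiner_decoder_proj  decoder_out ['N', 512]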
@@ -154,6 +178,44 @@ def test_joiner(
             (joiner_out - torch_joiner_out).abs().max()
         )

+        joiner_encoder_proj_inputs = {
+            "encoder_out": encoder_out.squeeze(1).squeeze(1).numpy()
+        }
+        joiner_encoder_proj_out = joiner_encoder_proj_session.run(
+            ["encoder_proj"], joiner_encoder_proj_inputs
+        )[0]
+        joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
+
+        torch_joiner_encoder_proj_out = model.joiner.encoder_proj(
+            encoder_out.squeeze(1).squeeze(1)
+        )
+        assert torch.allclose(
+            joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
+        ), (
+            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
+            .abs()
+            .max()
+        )
+
+        joiner_decoder_proj_inputs = {
+            "decoder_out": decoder_out.squeeze(1).squeeze(1).numpy()
+        }
+        joiner_decoder_proj_out = joiner_decoder_proj_session.run(
+            ["decoder_proj"], joiner_decoder_proj_inputs
+        )[0]
+        joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
+
+        torch_joiner_decoder_proj_out = model.joiner.decoder_proj(
+            decoder_out.squeeze(1).squeeze(1)
+        )
+        assert torch.allclose(
+            joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
+        ), (
+            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
+            .abs()
+            .max()
+        )
+

 @torch.no_grad()
 def main():
@@ -185,7 +247,20 @@ def main():
         args.onnx_joiner_filename,
         sess_options=options,
     )
-    test_joiner(model, joiner_session)
+    joiner_encoder_proj_session = ort.InferenceSession(
+        args.onnx_joiner_encoder_proj_filename,
+        sess_options=options,
+    )
+    joiner_decoder_proj_session = ort.InferenceSession(
+        args.onnx_joiner_decoder_proj_filename,
+        sess_options=options,
+    )
+    test_joiner(
+        model,
+        joiner_session,
+        joiner_encoder_proj_session,
+        joiner_decoder_proj_session,
+    )
     logging.info("Finished checking ONNX models")

@@ -27,7 +27,7 @@ You can use the following command to get the exported models:

 Usage of this script:

-./pruned_transducer_stateless3/jit_trace_pretrained.py \
+./pruned_transducer_stateless3/onnx_pretrained.py \
   --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
   --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
   --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
@@ -194,6 +194,7 @@ def greedy_search(
             decoder_input_nodes[0].name: decoder_input.numpy(),
         },
     )[0].squeeze(1)
+    decoder_out = torch.from_numpy(decoder_out)

     offset = 0
     for batch_size in batch_size_list:
@@ -209,11 +210,17 @@ def greedy_search(
         logits = joiner.run(
             [joiner_output_nodes[0].name],
             {
-                joiner_input_nodes[0].name: current_encoder_out.numpy(),
-                joiner_input_nodes[1].name: decoder_out,
+                joiner_input_nodes[0]
+                .name: current_encoder_out.unsqueeze(1)
+                .unsqueeze(1)
+                .numpy(),
+                joiner_input_nodes[1]
+                .name: decoder_out.unsqueeze(1)
+                .unsqueeze(1)
+                .numpy(),
             },
         )[0]
-        logits = torch.from_numpy(logits)
+        logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
         # logits'shape (batch_size, vocab_size)

         assert logits.ndim == 2, logits.shape
@@ -236,6 +243,7 @@ def greedy_search(
                 decoder_input_nodes[0].name: decoder_input.numpy(),
             },
         )[0].squeeze(1)
+        decoder_out = torch.from_numpy(decoder_out)

     sorted_ans = [h[context_size:] for h in hyps]
     ans = []
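One plausible use of the separate exports (hedged; this commit itself only wires the 4-D joiner into greedy_search): a decoding loop can run each projection once per tensor instead of re-projecting inside every joiner call. Sketch, with assumed paths and shapes:

# Hedged sketch, not from this diff: amortize the projections in decoding.
import numpy as np
import onnxruntime as ort

enc_proj = ort.InferenceSession("exp/joiner_encoder_proj.onnx")
dec_proj = ort.InferenceSession("exp/joiner_decoder_proj.onnx")

T, dim = 100, 512
encoder_out = np.random.rand(T, dim).astype(np.float32)

# Project all T encoder frames in a single call...
enc = enc_proj.run(["encoder_proj"], {"encoder_out": encoder_out})[0]

# ...and project decoder_out only when the hypothesis changes.
decoder_out = np.random.rand(1, dim).astype(np.float32)
dec = dec_proj.run(["decoder_proj"], {"decoder_out": decoder_out})[0]

# A joiner exported with project_input=False (not what this commit traces)
# could then combine enc[t] and dec per frame without repeating the matmuls.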