Combine encoder/decoder/joiner into a single file.
parent 49aaaf8021
commit 3ebb52aa9b
@@ -20,13 +20,52 @@
# to a single one using model averaging.
"""
Usage:

(1) Export to torchscript model

./pruned_transducer_stateless3/export.py \
  --exp-dir ./pruned_transducer_stateless3/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10 \
  --jit 1

It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.

Note `cpu` in the name `cpu_jit.pt` means that the parameters, when loaded
into Python, are on CPU. You can use `to("cuda")` to move it to a CUDA device.
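
For reference, loading the torchscript export looks roughly like this (a minimal
sketch; only the path above comes from this script, the rest is standard PyTorch):

import torch

# Load the torchscript model produced by --jit 1. Parameters start on CPU.
model = torch.jit.load("./pruned_transducer_stateless3/exp/cpu_jit.pt")
model.eval()

# Optionally move it to a CUDA device for inference.
if torch.cuda.is_available():
    model = model.to("cuda")
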
(2) Export to ONNX format

./pruned_transducer_stateless3/export.py \
  --exp-dir ./pruned_transducer_stateless3/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10 \
  --onnx 1

It will generate the following files in the given `exp_dir`.
See `onnx_check.py` for how to use them.

- encoder.onnx
- decoder.onnx
- joiner.onnx
- all_in_one.onnx

The file `all_in_one.onnx` combines `encoder.onnx`, `decoder.onnx`, and
`joiner.onnx`. You can use `onnx.utils.Extractor` to extract them.
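
A rough sketch of that extraction step (assuming the merged graph keeps the
`encoder/`, `decoder/`, `joiner/` prefixes added in `main()` below):

import onnx
import onnx.utils

# Load the combined model and pull the encoder back out by naming its
# (prefixed) inputs and outputs.
all_in_one = onnx.load("./pruned_transducer_stateless3/exp/all_in_one.onnx")
extractor = onnx.utils.Extractor(all_in_one)
encoder = extractor.extract_model(
    input_names=["encoder/x", "encoder/x_lens"],
    output_names=["encoder/encoder_out", "encoder/encoder_out_lens"],
)
onnx.save(encoder, "encoder-extracted.onnx")
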
(3) Export `model.state_dict()`

./pruned_transducer_stateless3/export.py \
  --exp-dir ./pruned_transducer_stateless3/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10

It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
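
A minimal sketch of restoring those weights (assuming icefall's
`load_checkpoint(filename, model)` signature; `model` must be built the same
way as in train.py):

import torch
from icefall.checkpoint import load_checkpoint

def load_pretrained(model: torch.nn.Module, filename: str) -> torch.nn.Module:
    # load_checkpoint() reads the dict saved by this script (the averaged
    # state dict, stored under a "model" key in icefall's convention) and
    # restores it into `model`.
    load_checkpoint(filename, model)
    model.eval()
    return model
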
To use the generated file with `pruned_transducer_stateless3/decode.py`,
you can do:
@@ -46,10 +85,12 @@ you can do:

import argparse
import logging
import onnx
from pathlib import Path

import sentencepiece as spm
import torch
import torch.nn as nn
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
@@ -152,6 +193,150 @@ def get_parser():
    return parser


def export_encoder_model_onnx(
    encoder_model: nn.Module,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given encoder model to ONNX format.
    The exported model has two inputs:

        - x, a tensor of shape (N, T, C); dtype is torch.float32
        - x_lens, a tensor of shape (N,); dtype is torch.int64

    and it has two outputs:

        - encoder_out, a tensor of shape (N, T, C)
        - encoder_out_lens, a tensor of shape (N,)

    Note: The warmup argument is fixed to 1.

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    x = torch.zeros(1, 100, 80, dtype=torch.float32)
    x_lens = torch.tensor([100], dtype=torch.int64)

    # encoder_model = torch.jit.script(model.encoder)
    # It throws the following error for the above statement
    #
    # RuntimeError: Exporting the operator __is_ to ONNX opset version
    # 11 is not supported. Please feel free to request support or
    # submit a pull request on PyTorch GitHub.
    #
    # I cannot find which statement causes the above error.
    # torch.onnx.export() will use torch.jit.trace() internally, which
    # works well for the current reworked model
    warmup = 1.0
    torch.onnx.export(
        encoder_model,
        (x, x_lens, warmup),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens", "warmup"],
        output_names=["encoder_out", "encoder_out_lens"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "encoder_out": {0: "N", 1: "T"},
            "encoder_out_lens": {0: "N"},
        },
    )
    logging.info(f"Saved to {encoder_filename}")

def export_decoder_model_onnx(
    decoder_model: nn.Module,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format.

    The exported model has one input:

        - y: a torch.int64 tensor of shape (N, 2)

    and has one output:

        - decoder_out: a torch.float32 tensor of shape (N, 1, C)

    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
    need_pad = False  # Always False, so we can use torch.jit.trace() here
    # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
    # in this case
    torch.onnx.export(
        decoder_model,
        (y, need_pad),
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y", "need_pad"],
        output_names=["decoder_out"],
        dynamic_axes={
            "y": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )
    logging.info(f"Saved to {decoder_filename}")


def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format.
    The exported model has two inputs:

        - encoder_out: a tensor of shape (N, encoder_out_dim)
        - decoder_out: a tensor of shape (N, decoder_out_dim)

    and has one output:

        - joiner_out: a tensor of shape (N, vocab_size)

    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.
    """
    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)

    project_input = True
    # Note: We use torch.jit.trace() here
    torch.onnx.export(
        joiner_model,
        (encoder_out, decoder_out, project_input),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out", "decoder_out", "project_input"],
        output_names=["logit"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
            "logit": {0: "N"},
        },
    )
    logging.info(f"Saved to {joiner_filename}")

@torch.no_grad()
def main():
    args = get_parser().parse_args()
@@ -218,90 +403,50 @@ def main():

    model.to("cpu")
    model.eval()
    opset_version = 11

    if params.onnx:
    if params.onnx is True:
        opset_version = 11
        logging.info("Exporting to onnx format")
        if True:
            x = torch.zeros(1, 100, 80, dtype=torch.float32)
            x_lens = torch.tensor([100], dtype=torch.int64)
            warmup = 1.0
            encoder_filename = params.exp_dir / "encoder.onnx"
            # encoder_model = torch.jit.script(model.encoder)
            # It throws the following error for the above statement
            #
            # RuntimeError: Exporting the operator __is_ to ONNX opset version
            # 11 is not supported. Please feel free to request support or
            # submit a pull request on PyTorch GitHub.
            #
            # I cannot find which statement causes the above error.
            # torch.onnx.export() will use torch.jit.trace() internally, which
            # works well for the current reworked model
        encoder_filename = params.exp_dir / "encoder.onnx"
        export_encoder_model_onnx(
            model.encoder,
            encoder_filename,
            opset_version=opset_version,
        )

            encoder_model = model.encoder
        decoder_filename = params.exp_dir / "decoder.onnx"
        export_decoder_model_onnx(
            model.decoder,
            decoder_filename,
            opset_version=opset_version,
        )

            torch.onnx.export(
                encoder_model,
                (x, x_lens, warmup),
                encoder_filename,
                verbose=False,
                opset_version=opset_version,
                input_names=["x", "x_lens", "warmup"],
                output_names=["encoder_out", "encoder_out_lens"],
                dynamic_axes={
                    "x": {0: "N", 1: "T"},
                    "x_lens": {0: "N"},
                    "encoder_out": {0: "N", 1: "T"},
                    "encoder_out_lens": {0: "N"},
                },
            )
            logging.info(f"Saved to {encoder_filename}")
        joiner_filename = params.exp_dir / "joiner.onnx"
        export_joiner_model_onnx(
            model.joiner,
            joiner_filename,
            opset_version=opset_version,
        )

        if True:
            y = torch.zeros(10, 2, dtype=torch.int64)
            need_pad = False
            decoder_filename = params.exp_dir / "decoder.onnx"
            decoder_model = torch.jit.script(model.decoder)
            torch.onnx.export(
                decoder_model,
                (y, need_pad),
                decoder_filename,
                verbose=False,
                opset_version=opset_version,
                input_names=["y", "need_pad"],
                output_names=["decoder_out"],
                dynamic_axes={
                    "y": {0: "N", 1: "U"},
                    "decoder_out": {0: "N", 1: "U"},
                },
            )
            logging.info(f"Saved to {decoder_filename}")
        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
        encoder_onnx = onnx.load(encoder_filename)
        decoder_onnx = onnx.load(decoder_filename)
        joiner_onnx = onnx.load(joiner_filename)

        if True:
            encoder_out = torch.rand(1, 1, 3, 512, dtype=torch.float32)
            decoder_out = torch.rand(1, 1, 3, 512, dtype=torch.float32)
            project_input = False
            joiner_filename = params.exp_dir / "joiner.onnx"
            joiner_model = torch.jit.script(model.joiner)
            torch.onnx.export(
                joiner_model,
                (encoder_out, decoder_out, project_input),
                joiner_filename,
                verbose=False,
                opset_version=opset_version,
                input_names=["encoder_out", "decoder_out", "project_input"],
                output_names=["logit"],
                dynamic_axes={
                    "encoder_out": {0: "N", 1: "T", 2: "s_range"},
                    "decoder_out": {0: "N", 1: "T", 2: "s_range"},
                    "logit": {0: "N", 1: "T", 2: "s_range"},
                },
            )
            logging.info(f"Saved to {joiner_filename}")
        encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
        decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
        joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")

        return
        combined_model = onnx.compose.merge_models(
            encoder_onnx, decoder_onnx, io_map={}
        )
        combined_model = onnx.compose.merge_models(
            combined_model, joiner_onnx, io_map={}
        )
        onnx.save(combined_model, all_in_one_filename)
        logging.info(f"Saved to {all_in_one_filename}")

    if params.jit:
    elif params.jit is True:
        # We won't use the forward() method of the model in C++, so just ignore
        # it here.
        # Otherwise, one of its arguments is a ragged tensor and is not
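
Once `all_in_one.onnx` has been written, it can be sanity-checked and its
prefixed interface inspected (a sketch; the exact names follow from the
`add_prefix()` calls above). The remaining hunks below belong to the companion
checking script, `onnx_check.py`.

import onnx

combined = onnx.load("./pruned_transducer_stateless3/exp/all_in_one.onnx")
onnx.checker.check_model(combined)

# Inputs and outputs carry the "encoder/", "decoder/", "joiner/" prefixes
# added during merging, e.g. "encoder/x", "decoder/y", "joiner/encoder_out".
print([i.name for i in combined.graph.input])
print([o.name for o in combined.graph.output])
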

@@ -39,8 +39,8 @@ def get_parser():

    parser.add_argument(
        "--jit-filename",
        type=str,
        required=True,
        type=str,
        help="Path to the torchscript model",
    )

@@ -53,12 +53,14 @@ def get_parser():

    parser.add_argument(
        "--onnx-decoder-filename",
        required=True,
        type=str,
        help="Path to the onnx decoder model",
    )

    parser.add_argument(
        "--onnx-joiner-filename",
        required=True,
        type=str,
        help="Path to the onnx joiner model",
    )
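
Putting those flags together, the checker is presumably invoked like the export
commands above (a sketch; the `--onnx-encoder-filename` option is assumed to
exist alongside the two shown in this hunk):

./pruned_transducer_stateless3/onnx_check.py \
  --jit-filename ./pruned_transducer_stateless3/exp/cpu_jit.pt \
  --onnx-encoder-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
  --onnx-decoder-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
  --onnx-joiner-filename ./pruned_transducer_stateless3/exp/joiner.onnx
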
@@ -76,21 +78,25 @@ def test_encoder(
    assert encoder_inputs[0].shape == ["N", "T", 80]
    assert encoder_inputs[1].shape == ["N"]

    x = torch.rand(5, 50, 80, dtype=torch.float32)
    x_lens = torch.tensor([50, 50, 20, 30, 10])
    for N in [1, 5]:
        for T in [12, 25]:
            print("N, T", N, T)
            x = torch.rand(N, T, 80, dtype=torch.float32)
            x_lens = torch.randint(low=10, high=T + 1, size=(N,))
            x_lens[0] = T

    encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
    encoder_out, encoder_out_lens = encoder_session.run(
        ["encoder_out", "encoder_out_lens"],
        encoder_inputs,
    )
            encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
            encoder_out, encoder_out_lens = encoder_session.run(
                ["encoder_out", "encoder_out_lens"],
                encoder_inputs,
            )

    torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
            torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)

    encoder_out = torch.from_numpy(encoder_out)
    assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
        (encoder_out - torch_encoder_out).abs().max()
    )
            encoder_out = torch.from_numpy(encoder_out)
            assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
                (encoder_out - torch_encoder_out).abs().max()
            )

def test_decoder(
@@ -99,19 +105,53 @@ def test_decoder(
):
    decoder_inputs = decoder_session.get_inputs()
    assert decoder_inputs[0].name == "y"
    assert decoder_inputs[1].name == "need_pad"
    assert decoder_inputs[0].shape == ["N", "U"]
    y = torch.randint(low=1, high=500, size=(1, 2))
    assert decoder_inputs[0].shape == ["N", 2]
    for N in [1, 5, 10]:
        y = torch.randint(low=1, high=500, size=(N, 2))

    decoder_inputs = {"y": y.numpy(), "need_pad": np.array([False], dtype=bool)}
    decoder_out = decoder_session.run(
        ["decoder_out"],
        decoder_inputs,
    )[0]
    decoder_out = torch.from_numpy(decoder_out)
        decoder_inputs = {"y": y.numpy()}
        decoder_out = decoder_session.run(
            ["decoder_out"],
            decoder_inputs,
        )[0]
        decoder_out = torch.from_numpy(decoder_out)

    torch_decoder_out = model.decoder(y, need_pad=False)
    assert torch.allclose(decoder_out, torch_decoder_out)
        torch_decoder_out = model.decoder(y, need_pad=False)
        assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
            (decoder_out - torch_decoder_out).abs().max()
        )

def test_joiner(
    model: torch.jit.ScriptModule,
    joiner_session: ort.InferenceSession,
):
    joiner_inputs = joiner_session.get_inputs()
    assert joiner_inputs[0].name == "encoder_out"
    assert joiner_inputs[0].shape == ["N", 512]

    assert joiner_inputs[1].name == "decoder_out"
    assert joiner_inputs[1].shape == ["N", 512]

    for N in [1, 5, 10]:
        encoder_out = torch.rand(N, 512)
        decoder_out = torch.rand(N, 512)

        joiner_inputs = {
            "encoder_out": encoder_out.numpy(),
            "decoder_out": decoder_out.numpy(),
        }
        joiner_out = joiner_session.run(["logit"], joiner_inputs)[0]
        joiner_out = torch.from_numpy(joiner_out)

        torch_joiner_out = model.joiner(
            encoder_out,
            decoder_out,
            project_input=True,
        )
        assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
            (joiner_out - torch_joiner_out).abs().max()
        )

@torch.no_grad()
@@ -139,6 +179,13 @@ def main():
    )
    test_decoder(model, decoder_session)

    logging.info("Test joiner")
    joiner_session = ort.InferenceSession(
        args.onnx_joiner_filename,
        sess_options=options,
    )
    test_joiner(model, joiner_session)


if __name__ == "__main__":
    torch.manual_seed(20220727)