WIP: Support exporting to ONNX format

This commit is contained in:
Fangjun Kuang 2022-07-27 22:58:27 +08:00
parent 385645d533
commit 8c98599ded
6 changed files with 296 additions and 23 deletions

View File

@ -1010,16 +1010,42 @@ class RelPositionMultiheadAttention(nn.Module):
n == left_context + 2 * time1 - 1 n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1" ), f"{n} == {left_context} + 2 * {time1} - 1"
# Note: TorchScript requires explicit arg for stride() if torch.jit.is_scripting() or torch.jit.is_tracing():
batch_stride = x.stride(0) x = x.contiguous()
head_stride = x.stride(1) b = x.size(0)
time1_stride = x.stride(2) h = x.size(1)
n_stride = x.stride(3) t = x.size(2)
return x.as_strided( c = x.size(3)
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride), bh = b * h
storage_offset=n_stride * (time1 - 1),
) if False:
rows = torch.arange(start=t - 1, end=-1, step=-1).unsqueeze(-1)
cols = torch.arange(t)
indexes = rows + cols
# onnx does not support torch.tile
indexes = torch.tile(indexes, (bh, 1))
else:
rows = torch.arange(start=t - 1, end=-1, step=-1)
cols = torch.arange(t)
rows = torch.cat([rows] * bh).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, c)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(b, h, t, t)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward( def multi_head_attention_forward(
self, self,

View File

@ -53,9 +53,10 @@ class Joiner(nn.Module):
Return a tensor of shape (N, T, s_range, C). Return a tensor of shape (N, T, s_range, C).
""" """
assert encoder_out.ndim == decoder_out.ndim if not torch.jit.is_scripting() or not torch.jit.is_tracing():
assert encoder_out.ndim in (2, 4) assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.shape == decoder_out.shape assert encoder_out.ndim in (2, 4)
assert encoder_out.shape == decoder_out.shape
if project_input: if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj( logit = self.encoder_proj(encoder_out) + self.decoder_proj(

View File

@ -423,7 +423,7 @@ class ActivationBalancer(torch.nn.Module):
self.max_abs = max_abs self.max_abs = max_abs
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
return x return x
else: else:
return ActivationBalancerFunction.apply( return ActivationBalancerFunction.apply(
@ -472,7 +472,7 @@ class DoubleSwish(torch.nn.Module):
"""Return double-swish activation function which is an approximation to Swish(Swish(x)), """Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1). that we approximate closely with x * sigmoid(x-1).
""" """
if torch.jit.is_scripting(): if torch.jit.is_scripting() or torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0) return x * torch.sigmoid(x - 1.0)
else: else:
return DoubleSwishFunction.apply(x) return DoubleSwishFunction.apply(x)

View File

@ -117,6 +117,19 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--onnx",
type=str2bool,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. Three files will be generated:
- encoder.onnx
- decoder.onnx
- joiner.onnx
""",
)
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
@ -139,6 +152,7 @@ def get_parser():
return parser return parser
@torch.no_grad()
def main(): def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@ -165,7 +179,7 @@ def main():
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params, enable_giga=False)
model.to(device) model.to(device)
@ -185,7 +199,9 @@ def main():
) )
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1: elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else: else:
@ -196,12 +212,83 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt") filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
model.eval() )
model.to("cpu") model.to("cpu")
model.eval() model.eval()
opset_version = 11
if params.onnx:
logging.info("Exporting to onnx format")
if True:
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
warmup = 1.0
encoder_filename = params.exp_dir / "encoder.onnx"
# encoder_model = torch.jit.script(model.encoder)
encoder_model = model.encoder
torch.onnx.export(
encoder_model,
(x, x_lens, warmup),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens", "warmup"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
if True:
y = torch.zeros(10, 2, dtype=torch.int64)
need_pad = False
decoder_filename = params.exp_dir / "decoder.onnx"
decoder_model = torch.jit.script(model.decoder)
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N", 1: "U"},
"decoder_out": {0: "N", 1: "U"},
},
)
logging.info(f"Saved to {decoder_filename}")
if True:
encoder_out = torch.rand(1, 1, 3, 512, dtype=torch.float32)
decoder_out = torch.rand(1, 1, 3, 512, dtype=torch.float32)
project_input = False
joiner_filename = params.exp_dir / "joiner.onnx"
joiner_model = torch.jit.script(model.joiner)
torch.onnx.export(
joiner_model,
(encoder_out, decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out", "decoder_out", "project_input"],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N", 1: "T", 2: "s_range"},
"decoder_out": {0: "N", 1: "T", 2: "s_range"},
"logit": {0: "N", 1: "T", 2: "s_range"},
},
)
logging.info(f"Saved to {joiner_filename}")
return
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore # We won't use the forward() method of the model in C++, so just ignore

View File

@ -0,0 +1,150 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks that the exported ONNX models produce the same output
as the given TorchScript model for the same input.
"""
import argparse
import logging
import onnxruntime as ort
ort.set_default_logger_severity(3)
import numpy as np
import torch
def get_parser():
    """Build the command-line parser for this script.

    ``--jit-filename`` and ``--onnx-encoder-filename`` are mandatory.
    ``--onnx-decoder-filename`` and ``--onnx-joiner-filename`` are optional
    and default to ``None`` when omitted.

    Returns:
      An ``argparse.ArgumentParser`` with the four path options registered.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # One row per option: (flag, whether it is mandatory, help text).
    # Every option is a plain string path.
    path_options = (
        ("--jit-filename", True, "Path to the torchscript model"),
        ("--onnx-encoder-filename", True, "Path to the onnx encoder model"),
        ("--onnx-decoder-filename", False, "Path to the onnx decoder model"),
        ("--onnx-joiner-filename", False, "Path to the onnx joiner model"),
    )
    for flag, is_required, help_text in path_options:
        cli.add_argument(
            flag,
            type=str,
            required=is_required,
            help=help_text,
        )

    return cli
def test_encoder(
    model: torch.jit.ScriptModule,
    encoder_session: ort.InferenceSession,
):
    """Check that the ONNX encoder agrees with the TorchScript encoder.

    A random input of shape (1, 100, 80) is fed to both models. The encoder
    outputs must match within ``atol=1e-05`` and the output lengths must be
    identical.

    Args:
      model:
        The loaded TorchScript model. Its ``encoder`` attribute is called
        with ``(x, x_lens)`` and must return ``(encoder_out,
        encoder_out_lens)``.
      encoder_session:
        An onnxruntime session wrapping the exported encoder.

    Raises:
      AssertionError:
        If the session inputs do not have the expected names/shapes, or
        if either output differs between the two models.
    """
    input_nodes = encoder_session.get_inputs()
    assert input_nodes[0].name == "x"
    assert input_nodes[1].name == "x_lens"
    assert input_nodes[0].shape == ["N", "T", 80]
    assert input_nodes[1].shape == ["N"]

    x = torch.rand(1, 100, 80, dtype=torch.float32)
    x_lens = torch.tensor([100])

    encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
    encoder_out, encoder_out_lens = encoder_session.run(
        ["encoder_out", "encoder_out_lens"],
        encoder_inputs,
    )

    torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)

    encoder_out = torch.from_numpy(encoder_out)
    # On failure, report the largest absolute element-wise difference.
    assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
        (encoder_out - torch_encoder_out).abs().max()
    )

    # The lengths were previously fetched but never verified. Compare them
    # as plain Python lists so that an integer dtype difference between
    # numpy and torch does not matter.
    assert encoder_out_lens.tolist() == torch_encoder_out_lens.tolist(), (
        encoder_out_lens,
        torch_encoder_out_lens,
    )
def test_decoder(
    model: torch.jit.ScriptModule,
    decoder_session: ort.InferenceSession,
):
    """Check that the ONNX decoder agrees with the TorchScript decoder.

    A random integer tensor of shape (1, 2) with values in [1, 500) is fed
    to both models with ``need_pad`` set to False, and the outputs are
    compared with ``torch.allclose`` using its default tolerances.

    Args:
      model:
        The loaded TorchScript model. Its ``decoder`` attribute is called
        as ``decoder(y, need_pad=False)``.
      decoder_session:
        An onnxruntime session wrapping the exported decoder.

    Raises:
      AssertionError:
        If the session inputs do not have the expected names/shape, or
        if the two outputs differ.
    """
    onnx_inputs = decoder_session.get_inputs()
    for position, expected_name in enumerate(("y", "need_pad")):
        assert onnx_inputs[position].name == expected_name
    assert onnx_inputs[0].shape == ["N", "U"]

    tokens = torch.randint(low=1, high=500, size=(1, 2))

    feed = {
        "y": tokens.numpy(),
        "need_pad": np.array([False], dtype=bool),
    }
    onnx_results = decoder_session.run(["decoder_out"], feed)
    onnx_decoder_out = torch.from_numpy(onnx_results[0])

    reference_out = model.decoder(tokens, need_pad=False)
    assert torch.allclose(onnx_decoder_out, reference_out)
@torch.no_grad()
def main():
    """Load the TorchScript and ONNX models and compare their outputs.

    The encoder is always checked. The decoder is checked only when
    ``--onnx-decoder-filename`` is given; that option is optional in the
    parser, so passing ``None`` on to onnxruntime would otherwise fail
    after the encoder check has already run.

    Raises:
      AssertionError:
        Propagated from ``test_encoder`` / ``test_decoder`` when the
        exported models disagree with the TorchScript model.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    model = torch.jit.load(args.jit_filename)

    # Restrict onnxruntime to one inter-op and one intra-op thread.
    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    logging.info("Test encoder")
    encoder_session = ort.InferenceSession(
        args.onnx_encoder_filename,
        sess_options=options,
    )
    test_encoder(model, encoder_session)

    if args.onnx_decoder_filename is None:
        logging.warning(
            "--onnx-decoder-filename is not given. Skip testing decoder"
        )
    else:
        logging.info("Test decoder")
        decoder_session = ort.InferenceSession(
            args.onnx_decoder_filename,
            sess_options=options,
        )
        test_decoder(model, decoder_session)

    if args.onnx_joiner_filename is not None:
        # The option is accepted by the parser but this function does not
        # check the joiner; say so instead of silently ignoring it.
        logging.warning(
            "--onnx-joiner-filename is given, but the joiner is not tested"
        )
if __name__ == "__main__":
    # Seed torch's RNG before any random test input is generated.
    torch.manual_seed(20220727)

    log_format = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=log_format)

    main()

View File

@ -436,13 +436,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module: def get_transducer_model(
params: AttributeDict,
enable_giga: bool = True,
) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)
decoder_giga = get_decoder_model(params) if enable_giga:
joiner_giga = get_joiner_model(params) logging.info("Use giga")
decoder_giga = get_decoder_model(params)
joiner_giga = get_joiner_model(params)
else:
logging.info("Disable giga")
decoder_giga = None
joiner_giga = None
model = Transducer( model = Transducer(
encoder=encoder, encoder=encoder,