mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
WIP: Support exporting to ONNX format
This commit is contained in:
parent
385645d533
commit
8c98599ded
@ -1010,6 +1010,32 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
n == left_context + 2 * time1 - 1
|
||||
), f"{n} == {left_context} + 2 * {time1} - 1"
|
||||
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
x = x.contiguous()
|
||||
b = x.size(0)
|
||||
h = x.size(1)
|
||||
t = x.size(2)
|
||||
c = x.size(3)
|
||||
|
||||
bh = b * h
|
||||
|
||||
if False:
|
||||
rows = torch.arange(start=t - 1, end=-1, step=-1).unsqueeze(-1)
|
||||
cols = torch.arange(t)
|
||||
indexes = rows + cols
|
||||
# onnx does not support torch.tile
|
||||
indexes = torch.tile(indexes, (bh, 1))
|
||||
else:
|
||||
rows = torch.arange(start=t - 1, end=-1, step=-1)
|
||||
cols = torch.arange(t)
|
||||
rows = torch.cat([rows] * bh).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
|
||||
x = x.reshape(-1, c)
|
||||
x = torch.gather(x, dim=1, index=indexes)
|
||||
x = x.reshape(b, h, t, t)
|
||||
return x
|
||||
else:
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
|
@ -53,6 +53,7 @@ class Joiner(nn.Module):
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
|
||||
if not torch.jit.is_scripting() or not torch.jit.is_tracing():
|
||||
assert encoder_out.ndim == decoder_out.ndim
|
||||
assert encoder_out.ndim in (2, 4)
|
||||
assert encoder_out.shape == decoder_out.shape
|
||||
|
@ -423,7 +423,7 @@ class ActivationBalancer(torch.nn.Module):
|
||||
self.max_abs = max_abs
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
return ActivationBalancerFunction.apply(
|
||||
@ -472,7 +472,7 @@ class DoubleSwish(torch.nn.Module):
|
||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||
that we approximate closely with x * sigmoid(x-1).
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
else:
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
@ -117,6 +117,19 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If True, --jit is ignored and it exports the model
|
||||
to onnx format. Three files will be generated:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
@ -139,6 +152,7 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
@ -165,7 +179,7 @@ def main():
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
model = get_transducer_model(params, enable_giga=False)
|
||||
|
||||
model.to(device)
|
||||
|
||||
@ -185,7 +199,9 @@ def main():
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
@ -196,12 +212,83 @@ def main():
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
model.eval()
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
opset_version = 11
|
||||
|
||||
if params.onnx:
|
||||
logging.info("Exporting to onnx format")
|
||||
if True:
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
warmup = 1.0
|
||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
||||
# encoder_model = torch.jit.script(model.encoder)
|
||||
encoder_model = model.encoder
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, x_lens, warmup),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens", "warmup"],
|
||||
output_names=["encoder_out", "encoder_out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
if True:
|
||||
y = torch.zeros(10, 2, dtype=torch.int64)
|
||||
need_pad = False
|
||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||
decoder_model = torch.jit.script(model.decoder)
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
(y, need_pad),
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y", "need_pad"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N", 1: "U"},
|
||||
"decoder_out": {0: "N", 1: "U"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
if True:
|
||||
encoder_out = torch.rand(1, 1, 3, 512, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, 1, 3, 512, dtype=torch.float32)
|
||||
project_input = False
|
||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||
joiner_model = torch.jit.script(model.joiner)
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(encoder_out, decoder_out, project_input),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out", "decoder_out", "project_input"],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N", 1: "T", 2: "s_range"},
|
||||
"decoder_out": {0: "N", 1: "T", 2: "s_range"},
|
||||
"logit": {0: "N", 1: "T", 2: "s_range"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
return
|
||||
|
||||
if params.jit:
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
|
150
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
Executable file
150
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
Executable file
@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script checks that exported onnx models produces the same output
|
||||
as the given torchscript model for the same input.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import onnxruntime as ort
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-decoder-filename",
|
||||
type=str,
|
||||
help="Path to the onnx decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-filename",
|
||||
type=str,
|
||||
help="Path to the onnx joiner model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_encoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_session: ort.InferenceSession,
|
||||
):
|
||||
encoder_inputs = encoder_session.get_inputs()
|
||||
assert encoder_inputs[0].name == "x"
|
||||
assert encoder_inputs[1].name == "x_lens"
|
||||
assert encoder_inputs[0].shape == ["N", "T", 80]
|
||||
assert encoder_inputs[1].shape == ["N"]
|
||||
|
||||
x = torch.rand(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100])
|
||||
|
||||
encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
|
||||
encoder_out, encoder_out_lens = encoder_session.run(
|
||||
["encoder_out", "encoder_out_lens"],
|
||||
encoder_inputs,
|
||||
)
|
||||
|
||||
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
|
||||
|
||||
encoder_out = torch.from_numpy(encoder_out)
|
||||
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
|
||||
(encoder_out - torch_encoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_decoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
decoder_session: ort.InferenceSession,
|
||||
):
|
||||
decoder_inputs = decoder_session.get_inputs()
|
||||
assert decoder_inputs[0].name == "y"
|
||||
assert decoder_inputs[1].name == "need_pad"
|
||||
assert decoder_inputs[0].shape == ["N", "U"]
|
||||
y = torch.randint(low=1, high=500, size=(1, 2))
|
||||
|
||||
decoder_inputs = {"y": y.numpy(), "need_pad": np.array([False], dtype=bool)}
|
||||
decoder_out = decoder_session.run(
|
||||
["decoder_out"],
|
||||
decoder_inputs,
|
||||
)[0]
|
||||
decoder_out = torch.from_numpy(decoder_out)
|
||||
|
||||
torch_decoder_out = model.decoder(y, need_pad=False)
|
||||
assert torch.allclose(decoder_out, torch_decoder_out)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = torch.jit.load(args.jit_filename)
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
logging.info("Test encoder")
|
||||
encoder_session = ort.InferenceSession(
|
||||
args.onnx_encoder_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_encoder(model, encoder_session)
|
||||
|
||||
logging.info("Test decoder")
|
||||
decoder_session = ort.InferenceSession(
|
||||
args.onnx_decoder_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_decoder(model, decoder_session)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220727)
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -436,13 +436,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
def get_transducer_model(
|
||||
params: AttributeDict,
|
||||
enable_giga: bool = True,
|
||||
) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
if enable_giga:
|
||||
logging.info("Use giga")
|
||||
decoder_giga = get_decoder_model(params)
|
||||
joiner_giga = get_joiner_model(params)
|
||||
else:
|
||||
logging.info("Disable giga")
|
||||
decoder_giga = None
|
||||
joiner_giga = None
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
|
Loading…
x
Reference in New Issue
Block a user