From 3ebb52aa9b5872432641f637af43716315a47767 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 29 Jul 2022 21:14:20 +0800
Subject: [PATCH] Combine encoder/decoder/joiner into a single file.

---
 .../pruned_transducer_stateless3/export.py | 299 +++++++++++++-----
 .../onnx_check.py                          |  95 ++++--
 2 files changed, 293 insertions(+), 101 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index 1fa39ceb7..bd1dc0e20 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -20,13 +20,52 @@
 # to a single one using model averaging.
 """
 Usage:
+
+(1) Export to torchscript model
+
+./pruned_transducer_stateless3/export.py \
+  --exp-dir ./pruned_transducer_stateless3/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note: `cpu` in the name `cpu_jit.pt` means the parameters are on CPU when the
+model is loaded into Python. You can use `to("cuda")` to move it to a CUDA device.
+
+(2) Export to ONNX format
+
+./pruned_transducer_stateless3/export.py \
+  --exp-dir ./pruned_transducer_stateless3/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --onnx 1
+
+It will generate the following files in the given `exp_dir`.
+See `onnx_check.py` for how to use them.
+
+    - encoder.onnx
+    - decoder.onnx
+    - joiner.onnx
+    - all_in_one.onnx
+
+The file `all_in_one.onnx` combines `encoder.onnx`, `decoder.onnx`, and
+`joiner.onnx`. You can use `onnx.utils.Extractor` to extract them.
+
+(3) Export `model.state_dict()`
+
 ./pruned_transducer_stateless3/export.py \
   --exp-dir ./pruned_transducer_stateless3/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
   --epoch 20 \
   --avg 10
 
-It will generate a file exp_dir/pretrained.pt
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
 
 To use the generated file with `pruned_transducer_stateless3/decode.py`,
 you can do:
@@ -46,10 +85,12 @@ you can do:
 
 import argparse
 import logging
+import onnx
 from pathlib import Path
 
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -152,6 +193,150 @@ def get_parser():
     return parser
 
 
+def export_encoder_model_onnx(
+    encoder_model: nn.Module,
+    encoder_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the given encoder model to ONNX format.
+    The exported model has two inputs:
+
+        - x, a tensor of shape (N, T, C); dtype is torch.float32
+        - x_lens, a tensor of shape (N,); dtype is torch.int64
+
+    and it has two outputs:
+
+        - encoder_out, a tensor of shape (N, T, C)
+        - encoder_out_lens, a tensor of shape (N,)
+
+    Note: The warmup argument is fixed to 1.
+
+    Args:
+      encoder_model:
+        The encoder model to be exported.
+      encoder_filename:
+        The filename to save the exported ONNX model.
+      opset_version:
+        The opset version to use.
+    """
+    x = torch.zeros(1, 100, 80, dtype=torch.float32)
+    x_lens = torch.tensor([100], dtype=torch.int64)
+
+    # encoder_model = torch.jit.script(model.encoder)
+    # It throws the following error for the above statement
+    #
+    # RuntimeError: Exporting the operator __is_ to ONNX opset version
+    # 11 is not supported. Please feel free to request support or
+    # submit a pull request on PyTorch GitHub.
+    #
+    # I cannot find which statement causes the above error.
+    # torch.onnx.export() will use torch.jit.trace() internally, which
+    # works well for the current reworked model
+    warmup = 1.0
+    torch.onnx.export(
+        encoder_model,
+        (x, x_lens, warmup),
+        encoder_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["x", "x_lens", "warmup"],
+        output_names=["encoder_out", "encoder_out_lens"],
+        dynamic_axes={
+            "x": {0: "N", 1: "T"},
+            "x_lens": {0: "N"},
+            "encoder_out": {0: "N", 1: "T"},
+            "encoder_out_lens": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {encoder_filename}")
+
+
+def export_decoder_model_onnx(
+    decoder_model: nn.Module,
+    decoder_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the decoder model to ONNX format.
+
+    The exported model has one input:
+
+        - y: a torch.int64 tensor of shape (N, 2)
+
+    and has one output:
+
+        - decoder_out: a torch.float32 tensor of shape (N, 1, C)
+
+    Args:
+      decoder_model:
+        The decoder model to be exported.
+      decoder_filename:
+        Filename to save the exported ONNX model.
+      opset_version:
+        The opset version to use.
+    """
+    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
+    need_pad = False  # Always False, so we can use torch.jit.trace() here
+    # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
+    # in this case
+    torch.onnx.export(
+        decoder_model,
+        (y, need_pad),
+        decoder_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["y", "need_pad"],
+        output_names=["decoder_out"],
+        dynamic_axes={
+            "y": {0: "N"},
+            "decoder_out": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {decoder_filename}")
+
+
+def export_joiner_model_onnx(
+    joiner_model: nn.Module,
+    joiner_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the joiner model to ONNX format.
+    The exported model has two inputs:
+
+        - encoder_out: a tensor of shape (N, encoder_out_dim)
+        - decoder_out: a tensor of shape (N, decoder_out_dim)
+
+    and has one output:
+
+        - joiner_out: a tensor of shape (N, vocab_size)
+
+    Note: The argument project_input is fixed to True. Users should not
+    project encoder_out/decoder_out themselves; the exported joiner
+    does that for them.
+    """
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+
+    project_input = True
+    # Note: We use torch.jit.trace() here
+    torch.onnx.export(
+        joiner_model,
+        (encoder_out, decoder_out, project_input),
+        joiner_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["encoder_out", "decoder_out", "project_input"],
+        output_names=["logit"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "decoder_out": {0: "N"},
+            "logit": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {joiner_filename}")
+
+
 @torch.no_grad()
 def main():
     args = get_parser().parse_args()
@@ -218,90 +403,50 @@ def main():
     model.to("cpu")
     model.eval()
 
-    opset_version = 11
-    if params.onnx:
+    if params.onnx is True:
+        opset_version = 11
         logging.info("Exporting to onnx format")
-        if True:
-            x = torch.zeros(1, 100, 80, dtype=torch.float32)
-            x_lens = torch.tensor([100], dtype=torch.int64)
-            warmup = 1.0
-            encoder_filename = params.exp_dir / "encoder.onnx"
-            # encoder_model = torch.jit.script(model.encoder)
-            # It throws the following error for the above statement
-            #
-            # RuntimeError: Exporting the operator __is_ to ONNX opset version
-            # 11 is not supported. Please feel free to request support or
-            # submit a pull request on PyTorch GitHub.
-            #
-            # I cannot find which statement causes the above error.
-            # torch.onnx.export() will use torch.jit.trace() internally, which
-            # works well for the current reworked model
+        encoder_filename = params.exp_dir / "encoder.onnx"
+        export_encoder_model_onnx(
+            model.encoder,
+            encoder_filename,
+            opset_version=opset_version,
+        )
 
-            encoder_model = model.encoder
+        decoder_filename = params.exp_dir / "decoder.onnx"
+        export_decoder_model_onnx(
+            model.decoder,
+            decoder_filename,
+            opset_version=opset_version,
+        )
 
-            torch.onnx.export(
-                encoder_model,
-                (x, x_lens, warmup),
-                encoder_filename,
-                verbose=False,
-                opset_version=opset_version,
-                input_names=["x", "x_lens", "warmup"],
-                output_names=["encoder_out", "encoder_out_lens"],
-                dynamic_axes={
-                    "x": {0: "N", 1: "T"},
-                    "x_lens": {0: "N"},
-                    "encoder_out": {0: "N", 1: "T"},
-                    "encoder_out_lens": {0: "N"},
-                },
-            )
-            logging.info(f"Saved to {encoder_filename}")
+        joiner_filename = params.exp_dir / "joiner.onnx"
+        export_joiner_model_onnx(
+            model.joiner,
+            joiner_filename,
+            opset_version=opset_version,
+        )
 
-        if True:
-            y = torch.zeros(10, 2, dtype=torch.int64)
-            need_pad = False
-            decoder_filename = params.exp_dir / "decoder.onnx"
-            decoder_model = torch.jit.script(model.decoder)
-            torch.onnx.export(
-                decoder_model,
-                (y, need_pad),
-                decoder_filename,
-                verbose=False,
-                opset_version=opset_version,
-                input_names=["y", "need_pad"],
-                output_names=["decoder_out"],
-                dynamic_axes={
-                    "y": {0: "N", 1: "U"},
-                    "decoder_out": {0: "N", 1: "U"},
-                },
-            )
-            logging.info(f"Saved to {decoder_filename}")
+        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
+        encoder_onnx = onnx.load(encoder_filename)
+        decoder_onnx = onnx.load(decoder_filename)
+        joiner_onnx = onnx.load(joiner_filename)
 
-        if True:
-            encoder_out = torch.rand(1, 1, 3, 512, dtype=torch.float32)
-            decoder_out = torch.rand(1, 1, 3, 512, dtype=torch.float32)
-            project_input = False
-            joiner_filename = params.exp_dir / "joiner.onnx"
-            joiner_model = torch.jit.script(model.joiner)
-            torch.onnx.export(
-                joiner_model,
-                (encoder_out, decoder_out, project_input),
-                joiner_filename,
-                verbose=False,
-                opset_version=opset_version,
-                input_names=["encoder_out", "decoder_out", "project_input"],
-                output_names=["logit"],
-                dynamic_axes={
-                    "encoder_out": {0: "N", 1: "T", 2: "s_range"},
-                    "decoder_out": {0: "N", 1: "T", 2: "s_range"},
-                    "logit": {0: "N", 1: "T", 2: "s_range"},
-                },
-            )
-            logging.info(f"Saved to {joiner_filename}")
+        encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
+        decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
+        joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
 
-        return
+        combined_model = onnx.compose.merge_models(
+            encoder_onnx, decoder_onnx, io_map={}
+        )
+        combined_model = onnx.compose.merge_models(
+            combined_model, joiner_onnx, io_map={}
+        )
+        onnx.save(combined_model, all_in_one_filename)
+        logging.info(f"Saved to {all_in_one_filename}")
 
-    if params.jit:
+    elif params.jit is True:
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
index d7379c22e..df0d0f09e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
@@ -39,8 +39,8 @@ def get_parser():
 
     parser.add_argument(
         "--jit-filename",
-        type=str,
         required=True,
+        type=str,
         help="Path to the torchscript model",
     )
 
@@ -53,12 +53,14 @@ def get_parser():
 
     parser.add_argument(
         "--onnx-decoder-filename",
+        required=True,
         type=str,
        help="Path to the onnx decoder model",
     )
 
     parser.add_argument(
         "--onnx-joiner-filename",
+        required=True,
         type=str,
         help="Path to the onnx joiner model",
     )
@@ -76,21 +78,25 @@ def test_encoder(
     assert encoder_inputs[0].shape == ["N", "T", 80]
     assert encoder_inputs[1].shape == ["N"]
 
-    x = torch.rand(5, 50, 80, dtype=torch.float32)
-    x_lens = torch.tensor([50, 50, 20, 30, 10])
+    for N in [1, 5]:
+        for T in [12, 25]:
+            logging.info(f"N={N}, T={T}")
+            x = torch.rand(N, T, 80, dtype=torch.float32)
+            x_lens = torch.randint(low=10, high=T + 1, size=(N,))
+            x_lens[0] = T
 
-    encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
-    encoder_out, encoder_out_lens = encoder_session.run(
-        ["encoder_out", "encoder_out_lens"],
-        encoder_inputs,
-    )
+            encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
+            encoder_out, encoder_out_lens = encoder_session.run(
+                ["encoder_out", "encoder_out_lens"],
+                encoder_inputs,
+            )
 
-    torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
+            torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
 
-    encoder_out = torch.from_numpy(encoder_out)
-    assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
-        (encoder_out - torch_encoder_out).abs().max()
-    )
+            encoder_out = torch.from_numpy(encoder_out)
+            assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
+                (encoder_out - torch_encoder_out).abs().max()
+            )
 
 
 def test_decoder(
@@ -99,19 +105,53 @@ def test_decoder(
 ):
     decoder_inputs = decoder_session.get_inputs()
     assert decoder_inputs[0].name == "y"
-    assert decoder_inputs[1].name == "need_pad"
-    assert decoder_inputs[0].shape == ["N", "U"]
-    y = torch.randint(low=1, high=500, size=(1, 2))
+    assert decoder_inputs[0].shape == ["N", 2]
+    for N in [1, 5, 10]:
+        y = torch.randint(low=1, high=500, size=(N, 2))
 
-    decoder_inputs = {"y": y.numpy(), "need_pad": np.array([False], dtype=bool)}
-    decoder_out = decoder_session.run(
-        ["decoder_out"],
-        decoder_inputs,
-    )[0]
-    decoder_out = torch.from_numpy(decoder_out)
+        decoder_inputs = {"y": y.numpy()}
+        decoder_out = decoder_session.run(
+            ["decoder_out"],
+            decoder_inputs,
+        )[0]
+        decoder_out = torch.from_numpy(decoder_out)
 
-    torch_decoder_out = model.decoder(y, need_pad=False)
-    assert torch.allclose(decoder_out, torch_decoder_out)
+        torch_decoder_out = model.decoder(y, need_pad=False)
+        assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
+            (decoder_out - torch_decoder_out).abs().max()
+        )
+
+
+def test_joiner(
+    model: torch.jit.ScriptModule,
+    joiner_session: ort.InferenceSession,
+):
+    joiner_inputs = joiner_session.get_inputs()
+    assert joiner_inputs[0].name == "encoder_out"
+    assert joiner_inputs[0].shape == ["N", 512]
+
+    assert joiner_inputs[1].name == "decoder_out"
+    assert joiner_inputs[1].shape == ["N", 512]
+
+    for N in [1, 5, 10]:
+        encoder_out = torch.rand(N, 512)
+        decoder_out = torch.rand(N, 512)
+
+        joiner_inputs = {
+            "encoder_out": encoder_out.numpy(),
+            "decoder_out": decoder_out.numpy(),
+        }
+        joiner_out = joiner_session.run(["logit"], joiner_inputs)[0]
+        joiner_out = torch.from_numpy(joiner_out)
+
+        torch_joiner_out = model.joiner(
+            encoder_out,
+            decoder_out,
+            project_input=True,
+        )
+        assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
+            (joiner_out - torch_joiner_out).abs().max()
+        )
 
 
 @torch.no_grad()
@@ -139,6 +179,13 @@ def main():
     )
     test_decoder(model, decoder_session)
 
+    logging.info("Test joiner")
+    joiner_session = ort.InferenceSession(
+        args.onnx_joiner_filename,
+        sess_options=options,
+    )
+    test_joiner(model, joiner_session)
+
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
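--
Note for reviewers: the docstring in export.py says the sub-models can be pulled
back out of `all_in_one.onnx` with `onnx.utils.Extractor`. Below is a minimal
sketch of what that looks like, assuming the `encoder/`, `decoder/`, and
`joiner/` prefixes added by `onnx.compose.add_prefix()` in this patch; the file
paths and output names (`decoder_extracted.onnx` etc.) are hypothetical, not
part of the patch.

    import onnx

    # Load the combined model produced by `export.py --onnx 1`.
    model = onnx.load("pruned_transducer_stateless3/exp/all_in_one.onnx")
    extractor = onnx.utils.Extractor(model)

    # add_prefix() renames every graph input/output, so the decoder's "y"
    # becomes "decoder/y" and its "decoder_out" becomes "decoder/decoder_out".
    # (need_pad/project_input are traced booleans and, as onnx_check.py's
    # asserts confirm, do not survive as graph inputs.)
    decoder = extractor.extract_model(
        input_names=["decoder/y"],
        output_names=["decoder/decoder_out"],
    )
    onnx.save(decoder, "decoder_extracted.onnx")

    joiner = extractor.extract_model(
        input_names=["joiner/encoder_out", "joiner/decoder_out"],
        output_names=["joiner/logit"],
    )
    onnx.save(joiner, "joiner_extracted.onnx")

The encoder can be extracted the same way via "encoder/x" and "encoder/x_lens"
(plus "encoder/warmup" if tracing keeps it as a real graph input). The
extracted files can then be fed to `onnxruntime.InferenceSession` exactly as
`onnx_check.py` does with the standalone `decoder.onnx`/`joiner.onnx`, keeping
in mind that the extracted inputs/outputs retain their prefixed names.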