diff --git a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh index a4a6cd8d7..bb7c7dfdc 100755 --- a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh +++ b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh @@ -4,6 +4,8 @@ # The computed features are saved to ~/tmp/fbank-libri and are # cached for later runs +set -e + export PYTHONPATH=$PWD:$PYTHONPATH echo $PYTHONPATH diff --git a/.github/scripts/download-gigaspeech-dev-test-dataset.sh b/.github/scripts/download-gigaspeech-dev-test-dataset.sh index b9464de9f..f3564efc7 100755 --- a/.github/scripts/download-gigaspeech-dev-test-dataset.sh +++ b/.github/scripts/download-gigaspeech-dev-test-dataset.sh @@ -6,6 +6,8 @@ # You will find directories `~/tmp/giga-dev-dataset-fbank` after running # this script. +set -e + mkdir -p ~/tmp cd ~/tmp diff --git a/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh b/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh index 3efcc13e3..11704526c 100755 --- a/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh +++ b/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh @@ -7,6 +7,8 @@ # You will find directories ~/tmp/download/LibriSpeech after running # this script. +set -e + mkdir ~/tmp/download cd egs/librispeech/ASR ln -s ~/tmp/download . diff --git a/.github/scripts/install-kaldifeat.sh b/.github/scripts/install-kaldifeat.sh index 6666a5064..de30f7dfe 100755 --- a/.github/scripts/install-kaldifeat.sh +++ b/.github/scripts/install-kaldifeat.sh @@ -3,6 +3,8 @@ # This script installs kaldifeat into the directory ~/tmp/kaldifeat # which is cached by GitHub actions for later runs. +set -e + mkdir -p ~/tmp cd ~/tmp git clone https://github.com/csukuangfj/kaldifeat diff --git a/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh b/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh index e0b87e0fc..1b48aae27 100755 --- a/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh +++ b/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh @@ -4,6 +4,8 @@ # to egs/librispeech/ASR/download/LibriSpeech and generates manifest # files in egs/librispeech/ASR/data/manifests +set -e + cd egs/librispeech/ASR [ ! -e download ] && ln -s ~/tmp/download . 
mkdir -p data/manifests diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh index 631707ad9..aab2883a9 100755 --- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh +++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh index 528d04cd1..c8d9c6b77 100755 --- a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh +++ b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 19d606682..3d57a895c 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -1,4 +1,6 @@ #!/usr/bin/env bash +# +set -e log() { # This function is from espnet diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh index bd816c2d6..dafea56db 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index 6b5b51bd7..d1e4a3991 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index 62ea02c47..172d7ad4c 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index 34dbdf44d..880767443 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} @@ -62,15 +64,13 @@ log "Decode with ONNX models" --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx 
-./pruned_transducer_stateless3/onnx_check_all_in_one.py \ - --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-all-in-one-filename $repo/exp/all_in_one.onnx - ./pruned_transducer_stateless3/onnx_pretrained.py \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ --encoder-model-filename $repo/exp/encoder.onnx \ --decoder-model-filename $repo/exp/decoder.onnx \ --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index c893bc45a..c6a781318 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index d9dc34e48..af37102d5 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh index c22660d0a..5b8ed396b 100755 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 96a072c46..6368b0bbd 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh index dcc99d62e..209d4814f 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh index 9622224c9..34ff76fe4 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh 
b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh index 168aee766..75650c2d3 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh index 9211b22eb..bcc2d74cb 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh index 4a1dc1a7e..d3e40315a 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh index 5f8a5b3a5..cfa006776 100755 --- a/.github/scripts/run-pre-trained-transducer.sh +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index c10678549..bce8a6bd1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -476,8 +476,8 @@ class ConformerEncoderLayer(nn.Module): self, src: Tensor, pos_emb: Tensor, - src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + src_mask: Optional[Tensor] = None, warmup: float = 1.0, ) -> Tensor: """ @@ -486,8 +486,8 @@ class ConformerEncoderLayer(nn.Module): Args: src: the sequence to the encoder layer (required). pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + src_mask: the mask for the src sequence (optional). warmup: controls selective bypass of layers; if < 1.0, we will bypass layers more frequently. Shape: @@ -663,8 +663,8 @@ class ConformerEncoder(nn.Module): self, src: Tensor, pos_emb: Tensor, - mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + mask: Optional[Tensor] = None, warmup: float = 1.0, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -672,8 +672,8 @@ class ConformerEncoder(nn.Module): Args: src: the sequence to the encoder (required). pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + mask: the mask for the src sequence (optional). warmup: controls selective bypass of layers; if < 1.0, we will bypass layers more frequently.
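Note on the conformer.py change above: torch.jit.trace() and torch.onnx.export() bind their example inputs positionally, so moving src_key_padding_mask ahead of src_mask presumably lets the exporters be called with (src, pos_emb, src_key_padding_mask) and no keyword arguments; under the old order the third positional input would have been bound to src_mask. A minimal sketch of the call pattern, excerpted from the new test_onnx.py added later in this patch (encoder_pos, encoder_layer, x and x_lens are set up there):

x, pos_emb = encoder_pos(x)           # RelPositionalEncoding
x = x.permute(1, 0, 2)                # (N, T, C) -> (T, N, C)
src_key_padding_mask = make_pad_mask(x_lens)
# trace() feeds these positionally; with the new signature the padding mask
# lands on src_key_padding_mask and src_mask keeps its default of None.
jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask))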
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 11f24244e..36c8d6611 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -62,7 +62,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, --avg 10 \ --onnx 1 -It will generate the following six files in the given `exp_dir`. +It will generate the following files in the given `exp_dir`. Check `onnx_check.py` for how to use them. - encoder.onnx @@ -70,8 +70,8 @@ Check `onnx_check.py` for how to use them. - joiner.onnx - joiner_encoder_proj.onnx - joiner_decoder_proj.onnx - - all_in_one.onnx +Please see ./onnx_pretrained.py for usage of the generated files (4) Export `model.state_dict()` @@ -118,8 +118,6 @@ import argparse import logging from pathlib import Path -import onnx_graphsurgeon as gs -import onnx import sentencepiece as spm import torch import torch.nn as nn @@ -217,16 +215,15 @@ def get_parser(): type=str2bool, default=False, help="""If True, --jit is ignored and it exports the model - to onnx format. Three files will be generated: + to onnx format. It will generate the following files: - encoder.onnx - decoder.onnx - joiner.onnx - joiner_encoder_proj.onnx - joiner_decoder_proj.onnx - - all_in_one.onnx - Check ./onnx_check.py and ./onnx_pretrained.py for how to use them. + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. """, ) @@ -483,134 +480,99 @@ def export_joiner_model_onnx( opset_version: int = 11, ) -> None: """Export the joiner model to ONNX format. - The exported model has two inputs: + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + - decoder_out: a tensor of shape (N, decoder_out_dim) - and has one output: - - - joiner_out: a tensor of shape (N, vocab_size) + and produces one output: + - projected_decoder_out: a tensor of shape (N, joiner_dim) """ + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) + + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, 1, 1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, 1, 1, decoder_out_dim, dtype=torch.float32) + joiner_dim = joiner_model.decoder_proj.weight.shape[0] - project_input = True + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + + project_input = False # Note: It uses torch.jit.trace() internally torch.onnx.export( joiner_model, - (encoder_out, decoder_out, project_input), + (projected_encoder_out, projected_decoder_out, project_input), joiner_filename, verbose=False, opset_version=opset_version, - input_names=["encoder_out", "decoder_out", "project_input"], + input_names=[ + "projected_encoder_out", + "projected_decoder_out", + "project_input", + ], output_names=["logit"], dynamic_axes={ -
"encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, "logit": {0: "N"}, }, ) - torch.onnx.export( - joiner_model.encoder_proj, - (encoder_out.squeeze(0).squeeze(0)), - str(joiner_filename).replace(".onnx", "_encoder_proj.onnx"), - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["encoder_proj"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "encoder_proj": {0: "N"}, - }, - ) - torch.onnx.export( - joiner_model.decoder_proj, - (decoder_out.squeeze(0).squeeze(0)), - str(joiner_filename).replace(".onnx", "_decoder_proj.onnx"), - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["decoder_proj"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "decoder_proj": {0: "N"}, - }, - ) logging.info(f"Saved to {joiner_filename}") - -def add_variables( - model: nn.Module, combined_model: onnx.ModelProto -) -> onnx.ModelProto: - graph = gs.import_onnx(combined_model) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - node = gs.Node( - op="Identity", - name="constants_lm", - attrs={ - "blank_id": blank_id, - "unk_id": unk_id, - "context_size": context_size, + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, }, - inputs=[], - outputs=[], ) - graph.nodes.append(node) + logging.info(f"Saved to {encoder_proj_filename}") - graph = gs.export_onnx(graph) - return graph - - -def export_all_in_one_onnx( - model: nn.Module, - encoder_filename: str, - decoder_filename: str, - joiner_filename: str, - all_in_one_filename: str, -): - encoder_onnx = onnx.load(encoder_filename) - decoder_onnx = onnx.load(decoder_filename) - joiner_onnx = onnx.load(joiner_filename) - joiner_encoder_proj_onnx = onnx.load( - str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, ) - joiner_decoder_proj_onnx = onnx.load( - str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") - ) - - encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/") - decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/") - joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/") - joiner_encoder_proj_onnx = onnx.compose.add_prefix( - joiner_encoder_proj_onnx, prefix="joiner_encoder_proj/" - ) - joiner_decoder_proj_onnx = onnx.compose.add_prefix( - joiner_decoder_proj_onnx, prefix="joiner_decoder_proj/" - ) - - combined_model = onnx.compose.merge_models( - encoder_onnx, decoder_onnx, io_map={} - ) - combined_model = onnx.compose.merge_models( - combined_model, joiner_onnx, io_map={} - ) - combined_model = onnx.compose.merge_models( - combined_model, joiner_encoder_proj_onnx, io_map={} - ) - combined_model = onnx.compose.merge_models( - combined_model, joiner_decoder_proj_onnx, io_map={} - ) - combined_model = 
add_variables(model, combined_model) - onnx.save(combined_model, all_in_one_filename) - logging.info(f"Saved to {all_in_one_filename}") + logging.info(f"Saved to {decoder_proj_filename}") @torch.no_grad() @@ -704,15 +666,6 @@ def main(): joiner_filename, opset_version=opset_version, ) - - all_in_one_filename = params.exp_dir / "all_in_one.onnx" - export_all_in_one_onnx( - model, - encoder_filename, - decoder_filename, - joiner_filename, - all_in_one_filename, - ) elif params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.script()") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index a04b408d8..fb9adb44a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -84,11 +84,13 @@ def test_encoder( model: torch.jit.ScriptModule, encoder_session: ort.InferenceSession, ): - encoder_inputs = encoder_session.get_inputs() - assert encoder_inputs[0].name == "x" - assert encoder_inputs[1].name == "x_lens" - assert encoder_inputs[0].shape == ["N", "T", 80] - assert encoder_inputs[1].shape == ["N"] + inputs = encoder_session.get_inputs() + outputs = encoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", "T", 80] + assert inputs[1].shape == ["N"] for N in [1, 5]: for T in [12, 25]: @@ -98,11 +100,11 @@ def test_encoder( x_lens[0] = T encoder_inputs = { - "x": x.numpy(), - "x_lens": x_lens.numpy(), + input_names[0]: x.numpy(), + input_names[1]: x_lens.numpy(), } encoder_out, encoder_out_lens = encoder_session.run( - ["encoder_out", "encoder_out_lens"], + output_names, encoder_inputs, ) @@ -110,7 +112,9 @@ def test_encoder( encoder_out = torch.from_numpy(encoder_out) assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max() + (encoder_out - torch_encoder_out).abs().max(), + encoder_out.shape, + torch_encoder_out.shape, ) @@ -118,15 +122,18 @@ def test_decoder( model: torch.jit.ScriptModule, decoder_session: ort.InferenceSession, ): - decoder_inputs = decoder_session.get_inputs() - assert decoder_inputs[0].name == "y" - assert decoder_inputs[0].shape == ["N", 2] + inputs = decoder_session.get_inputs() + outputs = decoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", 2] for N in [1, 5, 10]: y = torch.randint(low=1, high=500, size=(10, 2)) - decoder_inputs = {"y": y.numpy()} + decoder_inputs = {input_names[0]: y.numpy()} decoder_out = decoder_session.run( - ["decoder_out"], + output_names, decoder_inputs, )[0] decoder_out = torch.from_numpy(decoder_out) @@ -144,51 +151,62 @@ def test_joiner( joiner_decoder_proj_session: ort.InferenceSession, ): joiner_inputs = joiner_session.get_inputs() - assert joiner_inputs[0].name == "encoder_out" - assert joiner_inputs[0].shape == ["N", 1, 1, 512] + joiner_outputs = joiner_session.get_outputs() + joiner_input_names = [n.name for n in joiner_inputs] + joiner_output_names = [n.name for n in joiner_outputs] - assert joiner_inputs[1].name == "decoder_out" - assert joiner_inputs[1].shape == ["N", 1, 1, 512] + assert joiner_inputs[0].shape == ["N", 512] + assert joiner_inputs[1].shape == ["N", 512] joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() - assert 
joiner_encoder_proj_inputs[0].name == "encoder_out" + encoder_proj_input_name = joiner_encoder_proj_inputs[0].name + assert joiner_encoder_proj_inputs[0].shape == ["N", 512] + joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() + encoder_proj_output_name = joiner_encoder_proj_outputs[0].name + joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() - assert joiner_decoder_proj_inputs[0].name == "decoder_out" + decoder_proj_input_name = joiner_decoder_proj_inputs[0].name + assert joiner_decoder_proj_inputs[0].shape == ["N", 512] + joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() + decoder_proj_output_name = joiner_decoder_proj_outputs[0].name + for N in [1, 5, 10]: - encoder_out = torch.rand(N, 1, 1, 512) - decoder_out = torch.rand(N, 1, 1, 512) + encoder_out = torch.rand(N, 512) + decoder_out = torch.rand(N, 512) + + projected_encoder_out = torch.rand(N, 512) + projected_decoder_out = torch.rand(N, 512) joiner_inputs = { - "encoder_out": encoder_out.numpy(), - "decoder_out": decoder_out.numpy(), + joiner_input_names[0]: projected_encoder_out.numpy(), + joiner_input_names[1]: projected_decoder_out.numpy(), } - joiner_out = joiner_session.run(["logit"], joiner_inputs)[0] + joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] joiner_out = torch.from_numpy(joiner_out) torch_joiner_out = model.joiner( - encoder_out, - decoder_out, - project_input=True, + projected_encoder_out, + projected_decoder_out, + project_input=False, ) assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( (joiner_out - torch_joiner_out).abs().max() ) + # Now test encoder_proj joiner_encoder_proj_inputs = { - "encoder_out": encoder_out.squeeze(1).squeeze(1).numpy() + encoder_proj_input_name: encoder_out.numpy() } joiner_encoder_proj_out = joiner_encoder_proj_session.run( - ["encoder_proj"], joiner_encoder_proj_inputs + [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) - torch_joiner_encoder_proj_out = model.joiner.encoder_proj( - encoder_out.squeeze(1).squeeze(1) - ) + torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 ), ( @@ -197,17 +215,16 @@ def test_joiner( .max() ) + # Now test decoder_proj joiner_decoder_proj_inputs = { - "decoder_out": decoder_out.squeeze(1).squeeze(1).numpy() + decoder_proj_input_name: decoder_out.numpy() } joiner_decoder_proj_out = joiner_decoder_proj_session.run( - ["decoder_proj"], joiner_decoder_proj_inputs + [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) - torch_joiner_decoder_proj_out = model.joiner.decoder_proj( - decoder_out.squeeze(1).squeeze(1) - ) + torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 ), ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py deleted file mode 100755 index b4cf8c94a..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py +++ /dev/null @@ -1,284 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Xiaomi Corporation (Author: Yunus Emre Ozkose) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, 
Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script checks that exported onnx models produce the same output -with the given torchscript model for the same input. -""" - -import argparse -import logging -import os - -import onnx -import onnx_graphsurgeon as gs -import onnxruntime -import onnxruntime as ort -import torch - -ort.set_default_logger_severity(3) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx-all-in-one-filename", - required=True, - type=str, - help="Path to the onnx all in one model", - ) - - return parser - - -def test_encoder( - model: torch.jit.ScriptModule, - encoder_session: ort.InferenceSession, -): - encoder_inputs = encoder_session.get_inputs() - assert encoder_inputs[0].shape == ["N", "T", 80] - assert encoder_inputs[1].shape == ["N"] - encoder_input_names = [i.name for i in encoder_inputs] - encoder_output_names = [i.name for i in encoder_session.get_outputs()] - - for N in [1, 5]: - for T in [12, 25]: - print("N, T", N, T) - x = torch.rand(N, T, 80, dtype=torch.float32) - x_lens = torch.randint(low=10, high=T + 1, size=(N,)) - x_lens[0] = T - - encoder_inputs = { - encoder_input_names[0]: x.numpy(), - encoder_input_names[1]: x_lens.numpy(), - } - encoder_out, encoder_out_lens = encoder_session.run( - [encoder_output_names[1], encoder_output_names[0]], - encoder_inputs, - ) - - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out = torch.from_numpy(encoder_out) - assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max() - ) - - -def test_decoder( - model: torch.jit.ScriptModule, - decoder_session: ort.InferenceSession, -): - decoder_inputs = decoder_session.get_inputs() - assert decoder_inputs[0].shape == ["N", 2] - decoder_input_names = [i.name for i in decoder_inputs] - decoder_output_names = [i.name for i in decoder_session.get_outputs()] - - for N in [1, 5, 10]: - y = torch.randint(low=1, high=500, size=(10, 2)) - - decoder_inputs = {decoder_input_names[0]: y.numpy()} - decoder_out = decoder_session.run( - [decoder_output_names[0]], - decoder_inputs, - )[0] - decoder_out = torch.from_numpy(decoder_out) - - torch_decoder_out = model.decoder(y, need_pad=False) - assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( - (decoder_out - torch_decoder_out).abs().max() - ) - - -def test_joiner( - model: torch.jit.ScriptModule, - joiner_session: ort.InferenceSession, -): - joiner_inputs = joiner_session.get_inputs() - assert joiner_inputs[0].shape == ["N", 512] - assert joiner_inputs[1].shape == ["N", 512] - joiner_input_names = [i.name for i in joiner_inputs] - joiner_output_names = [i.name for i in joiner_session.get_outputs()] - - for N in [1, 5, 10]: - encoder_out = torch.rand(N, 512) - decoder_out = torch.rand(N, 512) - - joiner_inputs = { - joiner_input_names[0]: 
encoder_out.numpy(), - joiner_input_names[1]: decoder_out.numpy(), - } - joiner_out = joiner_session.run( - [joiner_output_names[0]], joiner_inputs - )[0] - joiner_out = torch.from_numpy(joiner_out) - - torch_joiner_out = model.joiner( - encoder_out, - decoder_out, - project_input=True, - ) - assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( - (joiner_out - torch_joiner_out).abs().max() - ) - - -def extract_sub_model( - onnx_graph: onnx.ModelProto, - input_op_names: list, - output_op_names: list, - non_verbose=False, -): - onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph) - graph = gs.import_onnx(onnx_graph) - graph.cleanup().toposort() - - # Extraction of input OP and output OP - graph_node_inputs = [ - graph_nodes - for graph_nodes in graph.nodes - for graph_nodes_input in graph_nodes.inputs - if graph_nodes_input.name in input_op_names - ] - graph_node_outputs = [ - graph_nodes - for graph_nodes in graph.nodes - for graph_nodes_output in graph_nodes.outputs - if graph_nodes_output.name in output_op_names - ] - - # Init graph INPUT/OUTPUT - graph.inputs.clear() - graph.outputs.clear() - - # Update graph INPUT/OUTPUT - graph.inputs = [ - graph_node_input - for graph_node in graph_node_inputs - for graph_node_input in graph_node.inputs - if graph_node_input.shape - ] - graph.outputs = [ - graph_node_output - for graph_node in graph_node_outputs - for graph_node_output in graph_node.outputs - ] - - # Cleanup - graph.cleanup().toposort() - - # Shape Estimation - extracted_graph = None - try: - extracted_graph = onnx.shape_inference.infer_shapes( - gs.export_onnx(graph) - ) - except Exception: - extracted_graph = gs.export_onnx(graph) - if not non_verbose: - print( - "WARNING: " - + "The input shape of the next OP does not match the output shape. " - + "Be sure to open the .onnx file to verify the certainty of the geometry." 
- ) - return extracted_graph - - -def extract_encoder(onnx_model: onnx.ModelProto): - encoder_ = extract_sub_model( - onnx_model, - ["encoder/x", "encoder/x_lens"], - ["encoder/encoder_out", "encoder/encoder_out_lens"], - False, - ) - onnx.save(encoder_, "tmp_encoder.onnx") - onnx.checker.check_model(encoder_) - sess = onnxruntime.InferenceSession("tmp_encoder.onnx") - os.remove("tmp_encoder.onnx") - return sess - - -def extract_decoder(onnx_model: onnx.ModelProto): - decoder_ = extract_sub_model( - onnx_model, ["decoder/y"], ["decoder/decoder_out"], False - ) - onnx.save(decoder_, "tmp_decoder.onnx") - onnx.checker.check_model(decoder_) - sess = onnxruntime.InferenceSession("tmp_decoder.onnx") - os.remove("tmp_decoder.onnx") - return sess - - -def extract_joiner(onnx_model: onnx.ModelProto): - joiner_ = extract_sub_model( - onnx_model, - ["joiner/encoder_out", "joiner/decoder_out"], - ["joiner/logit"], - False, - ) - onnx.save(joiner_, "tmp_joiner.onnx") - onnx.checker.check_model(joiner_) - sess = onnxruntime.InferenceSession("tmp_joiner.onnx") - os.remove("tmp_joiner.onnx") - return sess - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - model = torch.jit.load(args.jit_filename) - onnx_model = onnx.load(args.onnx_all_in_one_filename) - - options = ort.SessionOptions() - options.inter_op_num_threads = 1 - options.intra_op_num_threads = 1 - - logging.info("Test encoder") - encoder_session = extract_encoder(onnx_model) - test_encoder(model, encoder_session) - - logging.info("Test decoder") - decoder_session = extract_decoder(onnx_model) - test_decoder(model, decoder_session) - - logging.info("Test joiner") - joiner_session = extract_joiner(onnx_model) - test_joiner(model, joiner_session) - logging.info("Finished checking ONNX models") - - -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 3e4a323aa..034217ad9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -59,21 +59,35 @@ def get_parser(): "--encoder-model-filename", type=str, required=True, - help="Path to the encoder torchscript model. ", + help="Path to the encoder onnx model. ", ) parser.add_argument( "--decoder-model-filename", type=str, required=True, - help="Path to the decoder torchscript model. ", + help="Path to the decoder onnx model. ", ) parser.add_argument( "--joiner-model-filename", type=str, required=True, - help="Path to the joiner torchscript model. ", + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. 
", ) parser.add_argument( @@ -136,6 +150,8 @@ def read_sound_files( def greedy_search( decoder: ort.InferenceSession, joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, encoder_out: np.ndarray, encoder_out_lens: np.ndarray, context_size: int, @@ -146,6 +162,10 @@ def greedy_search( The decoder model. joiner: The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. encoder_out: A 3-D tensor of shape (N, T, C) encoder_out_lens: @@ -167,6 +187,15 @@ def greedy_search( enforce_sorted=False, ) + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + { + joiner_encoder_proj.get_inputs()[ + 0 + ].name: packed_encoder_out.data.numpy() + }, + )[0] + blank_id = 0 # hard-code to 0 batch_size_list = packed_encoder_out.batch_sizes.tolist() @@ -194,30 +223,28 @@ def greedy_search( decoder_input_nodes[0].name: decoder_input.numpy(), }, )[0].squeeze(1) - decoder_out = torch.from_numpy(decoder_out) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) offset = 0 for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out + current_encoder_out = projected_encoder_out[start:end] # current_encoder_out's shape: (batch_size, encoder_out_dim) offset = end - decoder_out = decoder_out[:batch_size] + projected_decoder_out = projected_decoder_out[:batch_size] logits = joiner.run( [joiner_output_nodes[0].name], { - joiner_input_nodes[0] - .name: current_encoder_out.unsqueeze(1) - .unsqueeze(1) - .numpy(), - joiner_input_nodes[1] - .name: decoder_out.unsqueeze(1) - .unsqueeze(1) - .numpy(), + joiner_input_nodes[0].name: current_encoder_out, + joiner_input_nodes[1].name: projected_decoder_out.numpy(), }, )[0] logits = torch.from_numpy(logits).squeeze(1).squeeze(1) @@ -243,7 +270,11 @@ def greedy_search( decoder_input_nodes[0].name: decoder_input.numpy(), }, )[0].squeeze(1) - decoder_out = torch.from_numpy(decoder_out) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) sorted_ans = [h[context_size:] for h in hyps] ans = [] @@ -279,6 +310,16 @@ def main(): sess_options=session_opts, ) + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + sp = spm.SentencePieceProcessor() sp.load(args.bpe_model) @@ -323,6 +364,8 @@ def main(): hyps = greedy_search( decoder=decoder, joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + joiner_decoder_proj=joiner_decoder_proj, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, context_size=args.context_size, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py new file mode 100755 index 000000000..c55268b14 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +# Copyright 
2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file is to test that models can be exported to onnx. +""" +import os + +import onnxruntime as ort +import torch +from conformer import ( + Conformer, + ConformerEncoder, + ConformerEncoderLayer, + Conv2dSubsampling, + RelPositionalEncoding, +) +from scaling_converter import convert_scaled_to_non_scaled + +from icefall.utils import make_pad_mask + +ort.set_default_logger_severity(3) + + +def test_conv2d_subsampling(): + filename = "conv2d_subsampling.onnx" + opset_version = 11 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_embed = Conv2dSubsampling(num_features, d_model) + encoder_embed.eval() + encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True) + + jit_model = torch.jit.trace(encoder_embed, x) + + torch.onnx.export( + encoder_embed, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "y": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + + onnx_y = session.run(["y"], inputs)[0] + + onnx_y = torch.from_numpy(onnx_y) + torch_y = jit_model(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) + + os.remove(filename) + + +def test_rel_pos(): + filename = "rel_pos.onnx" + + opset_version = 11 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + jit_model = torch.jit.trace(encoder_pos, x) + + torch.onnx.export( + encoder_pos, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y", "pos_emb"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "y": {0: "N", 1: "T"}, + "pos_emb": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs) + onnx_y = torch.from_numpy(onnx_y) + onnx_pos_emb = torch.from_numpy(onnx_pos_emb) + + torch_y, torch_pos_emb = jit_model(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + 
) + + assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( + (onnx_pos_emb - torch_pos_emb).abs().max() + ) + print(onnx_y.abs().sum(), torch_y.abs().sum()) + print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum()) + + os.remove(filename) + + +def test_conformer_encoder_layer(): + filename = "conformer_encoder_layer.onnx" + opset_version = 11 + N = 30 + T = 50 + + d_model = 512 + nhead = 8 + dim_feedforward = 2048 + dropout = 0.1 + layer_dropout = 0.075 + cnn_module_kernel = 31 + causal = False + + x = torch.rand(N, T, d_model) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + src_key_padding_mask = make_pad_mask(x_lens) + + encoder_pos = RelPositionalEncoding(d_model, dropout) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x, pos_emb = encoder_pos(x) + x = x.permute(1, 0, 2) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + layer_dropout, + cnn_module_kernel, + causal, + ) + encoder_layer.eval() + encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) + + jit_model = torch.jit.trace( + encoder_layer, (x, pos_emb, src_key_padding_mask) + ) + + torch.onnx.export( + encoder_layer, + (x, pos_emb, src_key_padding_mask), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "pos_emb", "src_key_padding_mask"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "pos_emb": {0: "N", 1: "T"}, + "src_key_padding_mask": {0: "N", 1: "T"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: pos_emb.numpy(), + input_nodes[2].name: src_key_padding_mask.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = jit_model(x, pos_emb, src_key_padding_mask) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_conformer_encoder(): + filename = "conformer_encoder.onnx" + + opset_version = 11 + N = 3 + T = 15 + + d_model = 512 + nhead = 8 + dim_feedforward = 2048 + dropout = 0.1 + layer_dropout = 0.075 + cnn_module_kernel = 31 + causal = False + num_encoder_layers = 12 + + x = torch.rand(N, T, d_model) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + src_key_padding_mask = make_pad_mask(x_lens) + + encoder_pos = RelPositionalEncoding(d_model, dropout) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x, pos_emb = encoder_pos(x) + x = x.permute(1, 0, 2) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + layer_dropout, + cnn_module_kernel, + causal, + ) + encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + encoder.eval() + encoder = convert_scaled_to_non_scaled(encoder, inplace=True) + + jit_model = torch.jit.trace(encoder, (x, pos_emb, src_key_padding_mask)) + + torch.onnx.export( + encoder, + (x, pos_emb, src_key_padding_mask), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "pos_emb", "src_key_padding_mask"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "pos_emb": {0: "N", 1: "T"}, + 
"src_key_padding_mask": {0: "N", 1: "T"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: pos_emb.numpy(), + input_nodes[2].name: src_key_padding_mask.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = jit_model(x, pos_emb, src_key_padding_mask) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_conformer(): + filename = "conformer.onnx" + opset_version = 11 + N = 3 + T = 15 + num_features = 80 + x = torch.rand(N, T, num_features) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + + conformer = Conformer(num_features=num_features) + conformer.eval() + conformer = convert_scaled_to_non_scaled(conformer, inplace=True) + + jit_model = torch.jit.trace(conformer, (x, x_lens)) + torch.onnx.export( + conformer, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["y", "y_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "y": {0: "N", 1: "T"}, + "y_lens": {0: "N"}, + }, + ) + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: x_lens.numpy(), + } + onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs) + onnx_y = torch.from_numpy(onnx_y) + onnx_y_lens = torch.from_numpy(onnx_y_lens) + + torch_y, torch_y_lens = jit_model(x, x_lens) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) + + assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( + (onnx_y_lens - torch_y_lens).abs().max() + ) + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + print(onnx_y_lens, torch_y_lens) + + os.remove(filename) + + +@torch.no_grad() +def main(): + test_conv2d_subsampling() + test_rel_pos() + test_conformer_encoder_layer() + test_conformer_encoder() + test_conformer() + + +if __name__ == "__main__": + torch.manual_seed(20221011) + main() diff --git a/requirements.txt b/requirements.txt index 258c64065..d5931e49a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,4 @@ multi_quantization onnx onnxruntime --extra-index-url https://pypi.ngc.nvidia.com -onnx_graphsurgeon dill