diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 8e485d2e6..999841b80 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -30,6 +30,15 @@ ln -s pretrained.pt epoch-99.pt ls -lh *.pt popd +log "Test exporting to ONNX format" +./pruned_transducer_stateless7/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --onnx 1 + log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ @@ -41,6 +50,27 @@ log "Export to torchscript model" ls -lh $repo/exp/*.pt +log "Decode with ONNX models" + +./pruned_transducer_stateless7/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder.onnx \ + --onnx-decoder-filename $repo/exp/decoder.onnx \ + --onnx-joiner-filename $repo/exp/joiner.onnx \ + --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ + --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + log "Decode with models exported by torch.jit.script()" ./pruned_transducer_stateless7/jit_pretrained.py \ diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml index 365e2761a..7694e8bf5 100644 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml @@ -39,7 +39,7 @@ concurrency: jobs: run_librispeech_2022_11_11_zipformer: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 3e3160e7e..db8b5eb2b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -41,7 +41,31 @@ Check https://github.com/k2-fsa/sherpa for how to use the exported models outside of icefall. -(2) Export `model.state_dict()` +(2) Export to ONNX format + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. 
+ + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + +Please see ./onnx_pretrained.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +(3) Export `model.state_dict()` ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ @@ -172,6 +196,23 @@ def get_parser(): """, ) + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + parser.add_argument( "--context-size", type=int, @@ -184,6 +225,204 @@ def get_parser(): return parser +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 101, 80, dtype=torch.float32) + x_lens = torch.tensor([101], dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + @torch.no_grad() def main(): args = get_parser().parse_args() @@ -292,7 +531,31 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.onnx is True: + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / 
"encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + elif params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py new file mode 100755 index 000000000..63acc0922 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. +""" + +import argparse +import logging + +import onnxruntime as ort +import torch + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + + +ort.set_default_logger_severity(3) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + parser.add_argument( + "--onnx-joiner-encoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner encoder projection model", + ) + + parser.add_argument( + "--onnx-joiner-decoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner decoder projection model", + ) + + return parser + + +def test_encoder( + model: torch.jit.ScriptModule, + encoder_session: ort.InferenceSession, +): + inputs = encoder_session.get_inputs() + outputs = encoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", "T", 80] + assert inputs[1].shape == ["N"] + + for N in [1, 5]: + for T in [12, 50]: + print("N, T", N, T) + x = torch.rand(N, T, 80, dtype=torch.float32) + x_lens = torch.randint(low=10, high=T + 1, size=(N,)) + x_lens[0] = T + + encoder_inputs = { + input_names[0]: x.numpy(), + input_names[1]: 
x_lens.numpy(), + } + + torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) + + encoder_out, encoder_out_lens = encoder_session.run( + output_names, + encoder_inputs, + ) + + torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) + + encoder_out = torch.from_numpy(encoder_out) + assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( + (encoder_out - torch_encoder_out).abs().max(), + encoder_out.shape, + torch_encoder_out.shape, + ) + + +def test_decoder( + model: torch.jit.ScriptModule, + decoder_session: ort.InferenceSession, +): + inputs = decoder_session.get_inputs() + outputs = decoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", 2] + for N in [1, 5, 10]: + y = torch.randint(low=1, high=500, size=(10, 2)) + + decoder_inputs = {input_names[0]: y.numpy()} + decoder_out = decoder_session.run( + output_names, + decoder_inputs, + )[0] + decoder_out = torch.from_numpy(decoder_out) + + torch_decoder_out = model.decoder(y, need_pad=False) + assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( + (decoder_out - torch_decoder_out).abs().max() + ) + + +def test_joiner( + model: torch.jit.ScriptModule, + joiner_session: ort.InferenceSession, + joiner_encoder_proj_session: ort.InferenceSession, + joiner_decoder_proj_session: ort.InferenceSession, +): + joiner_inputs = joiner_session.get_inputs() + joiner_outputs = joiner_session.get_outputs() + joiner_input_names = [n.name for n in joiner_inputs] + joiner_output_names = [n.name for n in joiner_outputs] + + assert joiner_inputs[0].shape == ["N", 1, 1, 512] + assert joiner_inputs[1].shape == ["N", 1, 1, 512] + + joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() + encoder_proj_input_name = joiner_encoder_proj_inputs[0].name + + assert joiner_encoder_proj_inputs[0].shape == ["N", 384] + + joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() + encoder_proj_output_name = joiner_encoder_proj_outputs[0].name + + joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() + decoder_proj_input_name = joiner_decoder_proj_inputs[0].name + + assert joiner_decoder_proj_inputs[0].shape == ["N", 512] + + joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() + decoder_proj_output_name = joiner_decoder_proj_outputs[0].name + + for N in [1, 5, 10]: + encoder_out = torch.rand(N, 384) + decoder_out = torch.rand(N, 512) + + projected_encoder_out = torch.rand(N, 1, 1, 512) + projected_decoder_out = torch.rand(N, 1, 1, 512) + + joiner_inputs = { + joiner_input_names[0]: projected_encoder_out.numpy(), + joiner_input_names[1]: projected_decoder_out.numpy(), + } + joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] + joiner_out = torch.from_numpy(joiner_out) + + torch_joiner_out = model.joiner( + projected_encoder_out, + projected_decoder_out, + project_input=False, + ) + assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( + (joiner_out - torch_joiner_out).abs().max() + ) + + # Now test encoder_proj + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} + joiner_encoder_proj_out = joiner_encoder_proj_session.run( + [encoder_proj_output_name], joiner_encoder_proj_inputs + )[0] + joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) + + torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) + assert torch.allclose( + joiner_encoder_proj_out, torch_joiner_encoder_proj_out, 
atol=1e-5 + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) + + # Now test decoder_proj + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} + joiner_decoder_proj_out = joiner_decoder_proj_session.run( + [decoder_proj_output_name], joiner_decoder_proj_inputs + )[0] + joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) + + torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) + assert torch.allclose( + joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 + ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + model = torch.jit.load(args.jit_filename) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + logging.info("Test encoder") + encoder_session = ort.InferenceSession( + args.onnx_encoder_filename, + sess_options=options, + ) + test_encoder(model, encoder_session) + + logging.info("Test decoder") + decoder_session = ort.InferenceSession( + args.onnx_decoder_filename, + sess_options=options, + ) + test_decoder(model, decoder_session) + + logging.info("Test joiner") + joiner_session = ort.InferenceSession( + args.onnx_joiner_filename, + sess_options=options, + ) + joiner_encoder_proj_session = ort.InferenceSession( + args.onnx_joiner_encoder_proj_filename, + sess_options=options, + ) + joiner_decoder_proj_session = ort.InferenceSession( + args.onnx_joiner_decoder_proj_filename, + sess_options=options, + ) + test_joiner( + model, + joiner_session, + joiner_encoder_proj_session, + joiner_decoder_proj_session, + ) + logging.info("Finished checking ONNX models") + + +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py new file mode 100755 index 000000000..3a06ee293 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. 
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +Usage of this script: + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless7/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless7/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_decoder_proj.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import numpy as np +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. 
+ joiner: + The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + )[0] + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = projected_encoder_out[start:end] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + projected_decoder_out = projected_decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: np.expand_dims( + np.expand_dims(current_encoder_out, axis=1), axis=1 + ), + joiner_input_nodes[1] + .name: projected_decoder_out.unsqueeze(1) + .unsqueeze(1) + .numpy(), + }, + )[0] + logits = torch.from_numpy(logits).squeeze(1).squeeze(1) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = 
parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=session_opts, + ) + + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + joiner_decoder_proj=joiner_decoder_proj, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 1cbde6db0..156b91f09 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -261,7 +261,7 @@ class RandomGrad(torch.nn.Module): self.min_abs = min_abs def forward(self, x: Tensor): - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return x else: return RandomGradFunction.apply(x, self.min_abs) @@ -530,7 +530,7 @@ class ActivationBalancer(torch.nn.Module): self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not x.requires_grad: + if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): return _no_op(x) count = self.cpu_count @@ -790,7 +790,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x else: # a 
no-op function that will have a node in the autograd graph, @@ -862,6 +862,7 @@ class MaxEig(torch.nn.Module): torch.jit.is_scripting() or self.max_var_per_eig <= 0 or random.random() > self.cur_prob + or torch.jit.is_tracing() ): return _no_op(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py new file mode 100644 index 000000000..2440d267c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file is to test that models can be exported to onnx. +""" +import os + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + +import onnxruntime as ort +import torch +from scaling_converter import convert_scaled_to_non_scaled +from zipformer import ( + Conv2dSubsampling, + RelPositionalEncoding, + Zipformer, + ZipformerEncoder, + ZipformerEncoderLayer, +) + +ort.set_default_logger_severity(3) + + +def test_conv2d_subsampling(): + filename = "conv2d_subsampling.onnx" + opset_version = 13 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_embed = Conv2dSubsampling(num_features, d_model) + encoder_embed.eval() + encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True) + + torch.onnx.export( + encoder_embed, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "y": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + + onnx_y = session.run(["y"], inputs)[0] + + onnx_y = torch.from_numpy(onnx_y) + torch_y = encoder_embed(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + os.remove(filename) + + +def test_rel_pos(): + filename = "rel_pos.onnx" + + opset_version = 13 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x = x.permute(1, 0, 2) + + torch.onnx.export( + encoder_pos, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["pos_emb"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "pos_emb": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + 
options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + onnx_pos_emb = session.run(["pos_emb"], inputs) + onnx_pos_emb = torch.from_numpy(onnx_pos_emb[0]) + + torch_pos_emb = encoder_pos(x) + assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( + (onnx_pos_emb - torch_pos_emb).abs().max() + ) + print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum()) + + os.remove(filename) + + +def test_zipformer_encoder_layer(): + filename = "zipformer_encoder_layer.onnx" + opset_version = 13 + N = 30 + T = 50 + + d_model = 384 + attention_dim = 192 + nhead = 8 + feedforward_dim = 1024 + dropout = 0.1 + cnn_module_kernel = 31 + pos_dim = 4 + + x = torch.rand(N, T, d_model) + + encoder_pos = RelPositionalEncoding(d_model, dropout) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x = x.permute(1, 0, 2) + pos_emb = encoder_pos(x) + + encoder_layer = ZipformerEncoderLayer( + d_model, + attention_dim, + nhead, + feedforward_dim, + dropout, + cnn_module_kernel, + pos_dim, + ) + encoder_layer.eval() + encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) + + torch.onnx.export( + encoder_layer, + (x, pos_emb), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "pos_emb"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "pos_emb": {0: "N", 1: "T"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: pos_emb.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = encoder_layer(x, pos_emb) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_zipformer_encoder(): + filename = "zipformer_encoder.onnx" + + opset_version = 13 + N = 3 + T = 15 + + d_model = 512 + attention_dim = 192 + nhead = 8 + feedforward_dim = 1024 + dropout = 0.1 + cnn_module_kernel = 31 + pos_dim = 4 + num_encoder_layers = 12 + + warmup_batches = 4000.0 + warmup_begin = warmup_batches / (num_encoder_layers + 1) + warmup_end = warmup_batches / (num_encoder_layers + 1) + + x = torch.rand(N, T, d_model) + + encoder_layer = ZipformerEncoderLayer( + d_model, + attention_dim, + nhead, + feedforward_dim, + dropout, + cnn_module_kernel, + pos_dim, + ) + encoder = ZipformerEncoder( + encoder_layer, num_encoder_layers, dropout, warmup_begin, warmup_end + ) + encoder.eval() + encoder = convert_scaled_to_non_scaled(encoder, inplace=True) + + # jit_model = torch.jit.trace(encoder, (pos_emb)) + + torch_y = encoder(x) + + torch.onnx.export( + encoder, + (x), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + 
) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = encoder(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_zipformer(): + filename = "zipformer.onnx" + opset_version = 11 + N = 3 + T = 15 + num_features = 80 + x = torch.rand(N, T, num_features) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + + zipformer = Zipformer(num_features=num_features) + zipformer.eval() + zipformer = convert_scaled_to_non_scaled(zipformer, inplace=True) + + # jit_model = torch.jit.trace(zipformer, (x, x_lens)) + torch.onnx.export( + zipformer, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["y", "y_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "y": {0: "N", 1: "T"}, + "y_lens": {0: "N"}, + }, + ) + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: x_lens.numpy(), + } + onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs) + onnx_y = torch.from_numpy(onnx_y) + onnx_y_lens = torch.from_numpy(onnx_y_lens) + + torch_y, torch_y_lens = zipformer(x, x_lens) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( + (onnx_y_lens - torch_y_lens).abs().max() + ) + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + print(onnx_y_lens, torch_y_lens) + + os.remove(filename) + + +@torch.no_grad() +def main(): + test_conv2d_subsampling() + test_rel_pos() + test_zipformer_encoder_layer() + test_zipformer_encoder() + test_zipformer() + + +if __name__ == "__main__": + torch.manual_seed(20221011) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index d18258085..b1717ec64 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -210,7 +210,7 @@ class Zipformer(EncoderInterface): (num_frames, batch_size, encoder_dims0) """ num_encoders = len(self.encoder_dims) - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape @@ -293,7 +293,7 @@ class Zipformer(EncoderInterface): k = self.skip_layers[i] if isinstance(k, int): layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) @@ -386,7 +386,7 @@ class ZipformerEncoderLayer(nn.Module): ) def get_bypass_scale(self): - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return self.bypass_scale if random.random() < 0.1: # ensure we get grads if self.bypass_scale becomes out of range @@ -407,7 +407,7 @@ class 
ZipformerEncoderLayer(nn.Module): # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable # at the beginning, by making the network focus on the feedforward modules. - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return 0.0 warmup_period = 2000.0 initial_dropout_rate = 0.2 @@ -452,12 +452,12 @@ class ZipformerEncoderLayer(nn.Module): dynamic_dropout = self.get_dynamic_dropout_rate() # pooling module - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) elif random.random() >= dynamic_dropout: src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): src_att, attn_weights = self.self_attn( src, pos_emb=pos_emb, @@ -658,7 +658,7 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): layers_to_drop = [] else: rnd_seed = src.numel() + random.randint(0, 1000) @@ -667,7 +667,7 @@ class ZipformerEncoder(nn.Module): output = output * feature_mask for i, mod in enumerate(self.layers): - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if i in layers_to_drop: continue output = mod( @@ -864,7 +864,7 @@ class SimpleCombiner(torch.nn.Module): assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) weight1 = self.weight1 - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if ( self.training and random.random() < 0.25 @@ -1258,21 +1258,31 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if training and random.random() < 0.1: # This is a harder way of limiting the attention scores to not be too large. 
# It incurs a penalty if any of them has an absolute value greater than 50.0. @@ -1383,7 +1393,7 @@ class RelPositionMultiheadAttention(nn.Module): # now v: (bsz * num_heads, seq_len, head_dim // 2) attn_output = torch.bmm(attn_weights, v) - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if random.random() < 0.001 or __name__ == "__main__": self._print_attn_stats(attn_weights, attn_output) @@ -1458,7 +1468,10 @@ class PoolingModule(nn.Module): a Tensor of shape (1, N, C) """ if key_padding_mask is not None: - pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) + if torch.jit.is_tracing(): + pooling_mask = (~key_padding_mask).to(x.dtype) + else: + pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1)
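
The exported encoder declares `x`/`x_lens` as inputs and `encoder_out`/`encoder_out_lens` as outputs, with `N` and `T` as dynamic axes. A minimal shape-only smoke test, separate from onnx_check.py, might look like the sketch below; it assumes the export command from the docstring has been run so that ./pruned_transducer_stateless7/exp/encoder.onnx exists, and it feeds random features, so only the output shapes are meaningful.

#!/usr/bin/env python3
# Shape-only smoke test for the exported ONNX encoder (random input features).
import onnxruntime as ort
import torch

# Assumed path; adjust to wherever export.py wrote the model.
session = ort.InferenceSession("./pruned_transducer_stateless7/exp/encoder.onnx")

for N, T in [(1, 100), (3, 55)]:
    x = torch.rand(N, T, 80, dtype=torch.float32)
    x_lens = torch.full((N,), T, dtype=torch.int64)
    encoder_out, encoder_out_lens = session.run(
        ["encoder_out", "encoder_out_lens"],
        {"x": x.numpy(), "x_lens": x_lens.numpy()},
    )
    # The zipformer subsamples along T, so encoder_out has fewer frames than x.
    print(encoder_out.shape, encoder_out_lens)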
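
greedy_search() in onnx_pretrained.py always feeds the decoder the last context_size tokens of each hypothesis, and need_pad is fixed to False at export time, so it is not a real input of the traced ONNX graph. A hypothetical sketch of that calling pattern, with a made-up batch of hypotheses and the same assumed exp dir as above:

import numpy as np
import onnxruntime as ort

decoder = ort.InferenceSession("./pruned_transducer_stateless7/exp/decoder.onnx")

context_size = 2
blank_id = 0
# Three made-up hypotheses; during search each starts with `context_size` blanks.
hyps = [[blank_id, blank_id], [blank_id, 15], [23, 401]]
y = np.array([h[-context_size:] for h in hyps], dtype=np.int64)  # (N, context_size)

decoder_out = decoder.run(["decoder_out"], {"y": y})[0]  # (N, 1, C)
print(decoder_out.shape)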
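
The zipformer.py hunk above adds a tracing branch in RelPositionMultiheadAttention that expresses the relative-to-absolute position conversion with torch.gather() instead of Tensor.as_strided(), which is awkward to carry through torch.jit.trace()/ONNX export. Both formulations select element (seq_len - 1 - i + j) of row i, so they pick exactly the same values. A small self-contained check, using made-up sizes and assuming time1 == seq_len as in that code path:

import torch

# Made-up sizes for illustration only.
bsz, num_heads, seq_len = 2, 4, 9
n = 2 * seq_len - 1  # number of relative positions
pos_weights = torch.rand(bsz, num_heads, seq_len, n)

# Original path: as_strided() converts relative to absolute positions.
strided = pos_weights.as_strided(
    (bsz, num_heads, seq_len, seq_len),
    (
        pos_weights.stride(0),
        pos_weights.stride(1),
        pos_weights.stride(2) - pos_weights.stride(3),
        pos_weights.stride(3),
    ),
    storage_offset=pos_weights.stride(3) * (seq_len - 1),
)

# Tracing path: the same re-indexing expressed with gather(),
# selecting element (seq_len - 1 - i + j) of row i.
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
indexes = rows.repeat(bsz * num_heads).unsqueeze(-1) + cols
gathered = torch.gather(pos_weights.reshape(-1, n), dim=1, index=indexes)
gathered = gathered.reshape(bsz, num_heads, seq_len, seq_len)

assert torch.equal(strided, gathered)
print("as_strided and gather agree:", strided.shape)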