Add Zipformer Onnx Support (#778)

* add export script

* add zipformer onnx pretrained script

* add onnx zipformer test

* fix style

* add zipformer onnx to workflow

* replace is_in_onnx_export with is_tracing

* add github.event.label.name == 'onnx'

* add is_tracing to necessary conditions

* fix pooling_mask

* add onnx_check

* add onnx_check to scripts

* add is_tracing to scaling.py
Yunusemre 2023-01-03 08:59:44 +00:00 committed by GitHub
parent 80cce141b4
commit 0f26edfde9
8 changed files with 1383 additions and 28 deletions
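
The recurring change in scaling.py and zipformer.py is to extend every `torch.jit.is_scripting()` guard with `torch.jit.is_tracing()`, since `torch.onnx.export()` traces the model and must skip the same training-only/random branches. A minimal illustrative sketch of that guard pattern (not code from this commit; the module name is made up):

import torch
import torch.nn as nn

class GuardedDropout(nn.Module):
    # Hypothetical module, for illustration only.
    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Behave deterministically when scripting, tracing (ONNX export),
        # or running in eval mode, so the exported graph has no random branches.
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            return x
        return nn.functional.dropout(x, p=self.p, training=True)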


@@ -30,6 +30,15 @@ ln -s pretrained.pt epoch-99.pt
ls -lh *.pt
popd
log "Test exporting to ONNX format"
./pruned_transducer_stateless7/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--onnx 1
log "Export to torchscript model" log "Export to torchscript model"
./pruned_transducer_stateless7/export.py \ ./pruned_transducer_stateless7/export.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
@ -41,6 +50,27 @@ log "Export to torchscript model"
ls -lh $repo/exp/*.pt ls -lh $repo/exp/*.pt
log "Decode with ONNX models"
./pruned_transducer_stateless7/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder.onnx \
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx \
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
./pruned_transducer_stateless7/onnx_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
log "Decode with models exported by torch.jit.script()" log "Decode with models exported by torch.jit.script()"
./pruned_transducer_stateless7/jit_pretrained.py \ ./pruned_transducer_stateless7/jit_pretrained.py \


@@ -39,7 +39,7 @@ concurrency:
jobs:
run_librispeech_2022_11_11_zipformer:
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:


@@ -41,7 +41,31 @@ Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export to ONNX format
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Please see ./onnx_pretrained.py for usage of the generated files
Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
(3) Export `model.state_dict()`
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
@@ -172,6 +196,23 @@ def get_parser():
""",
)
parser.add_argument(
"--onnx",
type=str2bool,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--context-size",
type=int,
@@ -184,6 +225,204 @@ def get_parser():
return parser
def export_encoder_model_onnx(
encoder_model: nn.Module,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T, C)
- encoder_out_lens, a tensor of shape (N,)
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 101, 80, dtype=torch.float32)
x_lens = torch.tensor([101], dtype=torch.int64)
# encoder_model = torch.jit.script(encoder_model)
# It throws the following error for the above statement
#
# RuntimeError: Exporting the operator __is_ to ONNX opset version
# 11 is not supported. Please feel free to request support or
# submit a pull request on PyTorch GitHub.
#
# I cannot find which statement causes the above error.
# torch.onnx.export() will use torch.jit.trace() internally, which
# works well for the current reworked model
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = False # Always False, so we can use torch.jit.trace() here
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
# in this case
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_proj_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
@@ -292,7 +531,31 @@ def main():
model.to("cpu")
model.eval()
if params.onnx is True:
convert_scaled_to_non_scaled(model, inplace=True)
opset_version = 13
logging.info("Exporting to onnx format")
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
elif params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.


@@ -0,0 +1,286 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks that the exported onnx models produce the same output
as the given torchscript model for the same input.
"""
import argparse
import logging
import onnxruntime as ort
import torch
from icefall import is_module_available
if not is_module_available("onnxruntime"):
raise ValueError("Please 'pip install onnxruntime' first.")
ort.set_default_logger_severity(3)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit-filename",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx-encoder-filename",
required=True,
type=str,
help="Path to the onnx encoder model",
)
parser.add_argument(
"--onnx-decoder-filename",
required=True,
type=str,
help="Path to the onnx decoder model",
)
parser.add_argument(
"--onnx-joiner-filename",
required=True,
type=str,
help="Path to the onnx joiner model",
)
parser.add_argument(
"--onnx-joiner-encoder-proj-filename",
required=True,
type=str,
help="Path to the onnx joiner encoder projection model",
)
parser.add_argument(
"--onnx-joiner-decoder-proj-filename",
required=True,
type=str,
help="Path to the onnx joiner decoder projection model",
)
return parser
def test_encoder(
model: torch.jit.ScriptModule,
encoder_session: ort.InferenceSession,
):
inputs = encoder_session.get_inputs()
outputs = encoder_session.get_outputs()
input_names = [n.name for n in inputs]
output_names = [n.name for n in outputs]
assert inputs[0].shape == ["N", "T", 80]
assert inputs[1].shape == ["N"]
for N in [1, 5]:
for T in [12, 50]:
print("N, T", N, T)
x = torch.rand(N, T, 80, dtype=torch.float32)
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
x_lens[0] = T
encoder_inputs = {
input_names[0]: x.numpy(),
input_names[1]: x_lens.numpy(),
}
encoder_out, encoder_out_lens = encoder_session.run(
output_names,
encoder_inputs,
)
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
encoder_out = torch.from_numpy(encoder_out)
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
(encoder_out - torch_encoder_out).abs().max(),
encoder_out.shape,
torch_encoder_out.shape,
)
def test_decoder(
model: torch.jit.ScriptModule,
decoder_session: ort.InferenceSession,
):
inputs = decoder_session.get_inputs()
outputs = decoder_session.get_outputs()
input_names = [n.name for n in inputs]
output_names = [n.name for n in outputs]
assert inputs[0].shape == ["N", 2]
for N in [1, 5, 10]:
y = torch.randint(low=1, high=500, size=(N, 2))
decoder_inputs = {input_names[0]: y.numpy()}
decoder_out = decoder_session.run(
output_names,
decoder_inputs,
)[0]
decoder_out = torch.from_numpy(decoder_out)
torch_decoder_out = model.decoder(y, need_pad=False)
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
(decoder_out - torch_decoder_out).abs().max()
)
def test_joiner(
model: torch.jit.ScriptModule,
joiner_session: ort.InferenceSession,
joiner_encoder_proj_session: ort.InferenceSession,
joiner_decoder_proj_session: ort.InferenceSession,
):
joiner_inputs = joiner_session.get_inputs()
joiner_outputs = joiner_session.get_outputs()
joiner_input_names = [n.name for n in joiner_inputs]
joiner_output_names = [n.name for n in joiner_outputs]
assert joiner_inputs[0].shape == ["N", 1, 1, 512]
assert joiner_inputs[1].shape == ["N", 1, 1, 512]
joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
encoder_proj_input_name = joiner_encoder_proj_inputs[0].name
assert joiner_encoder_proj_inputs[0].shape == ["N", 384]
joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs()
encoder_proj_output_name = joiner_encoder_proj_outputs[0].name
joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
decoder_proj_input_name = joiner_decoder_proj_inputs[0].name
assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs()
decoder_proj_output_name = joiner_decoder_proj_outputs[0].name
for N in [1, 5, 10]:
encoder_out = torch.rand(N, 384)
decoder_out = torch.rand(N, 512)
projected_encoder_out = torch.rand(N, 1, 1, 512)
projected_decoder_out = torch.rand(N, 1, 1, 512)
joiner_inputs = {
joiner_input_names[0]: projected_encoder_out.numpy(),
joiner_input_names[1]: projected_decoder_out.numpy(),
}
joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0]
joiner_out = torch.from_numpy(joiner_out)
torch_joiner_out = model.joiner(
projected_encoder_out,
projected_decoder_out,
project_input=False,
)
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
(joiner_out - torch_joiner_out).abs().max()
)
# Now test encoder_proj
joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
joiner_encoder_proj_out = joiner_encoder_proj_session.run(
[encoder_proj_output_name], joiner_encoder_proj_inputs
)[0]
joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
assert torch.allclose(
joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
# Now test decoder_proj
joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
joiner_decoder_proj_out = joiner_decoder_proj_session.run(
[decoder_proj_output_name], joiner_decoder_proj_inputs
)[0]
joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
assert torch.allclose(
joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
model = torch.jit.load(args.jit_filename)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
logging.info("Test encoder")
encoder_session = ort.InferenceSession(
args.onnx_encoder_filename,
sess_options=options,
)
test_encoder(model, encoder_session)
logging.info("Test decoder")
decoder_session = ort.InferenceSession(
args.onnx_decoder_filename,
sess_options=options,
)
test_decoder(model, decoder_session)
logging.info("Test joiner")
joiner_session = ort.InferenceSession(
args.onnx_joiner_filename,
sess_options=options,
)
joiner_encoder_proj_session = ort.InferenceSession(
args.onnx_joiner_encoder_proj_filename,
sess_options=options,
)
joiner_decoder_proj_session = ort.InferenceSession(
args.onnx_joiner_decoder_proj_filename,
sess_options=options,
)
test_joiner(
model,
joiner_session,
joiner_encoder_proj_session,
joiner_decoder_proj_session,
)
logging.info("Finished checking ONNX models")
if __name__ == "__main__":
torch.manual_seed(20220727)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@@ -0,0 +1,388 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--onnx 1
Usage of this script:
./pruned_transducer_stateless7/onnx_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless7/exp/encoder.onnx \
--decoder-model-filename ./pruned_transducer_stateless7/exp/decoder.onnx \
--joiner-model-filename ./pruned_transducer_stateless7/exp/joiner.onnx \
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_decoder_proj.onnx \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--joiner-encoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner encoder_proj onnx model. ",
)
parser.add_argument(
"--joiner-decoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner decoder_proj onnx model. ",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="Context size of the decoder model",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: ort.InferenceSession,
joiner: ort.InferenceSession,
joiner_encoder_proj: ort.InferenceSession,
joiner_decoder_proj: ort.InferenceSession,
encoder_out: np.ndarray,
encoder_out_lens: np.ndarray,
context_size: int,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
decoder:
The decoder model.
joiner:
The joiner model.
joiner_encoder_proj:
The joiner encoder projection model.
joiner_decoder_proj:
The joiner decoder projection model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
context_size:
The context size of the decoder model.
Returns:
Return the decoded results for each utterance.
"""
encoder_out = torch.from_numpy(encoder_out)
encoder_out_lens = torch.from_numpy(encoder_out_lens)
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
projected_encoder_out = joiner_encoder_proj.run(
[joiner_encoder_proj.get_outputs()[0].name],
{joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
)[0]
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input_nodes = decoder.get_inputs()
decoder_output_nodes = decoder.get_outputs()
joiner_input_nodes = joiner.get_inputs()
joiner_output_nodes = joiner.get_outputs()
decoder_input = torch.tensor(
hyps,
dtype=torch.int64,
) # (N, context_size)
decoder_out = decoder.run(
[decoder_output_nodes[0].name],
{
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = projected_encoder_out[start:end]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
projected_decoder_out = projected_decoder_out[:batch_size]
logits = joiner.run(
[joiner_output_nodes[0].name],
{
joiner_input_nodes[0].name: np.expand_dims(
np.expand_dims(current_encoder_out, axis=1), axis=1
),
joiner_input_nodes[1]
.name: projected_decoder_out.unsqueeze(1)
.unsqueeze(1)
.numpy(),
},
)[0]
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
dtype=torch.int64,
)
decoder_out = decoder.run(
[decoder_output_nodes[0].name],
{
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
encoder = ort.InferenceSession(
args.encoder_model_filename,
sess_options=session_opts,
)
decoder = ort.InferenceSession(
args.decoder_model_filename,
sess_options=session_opts,
)
joiner = ort.InferenceSession(
args.joiner_model_filename,
sess_options=session_opts,
)
joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=session_opts,
)
joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=session_opts,
)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
encoder_input_nodes = encoder.get_inputs()
encoder_out_nodes = encoder.get_outputs()
encoder_out, encoder_out_lens = encoder.run(
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
{
encoder_input_nodes[0].name: features.numpy(),
encoder_input_nodes[1].name: feature_lengths.numpy(),
},
)
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
joiner_encoder_proj=joiner_encoder_proj,
joiner_decoder_proj=joiner_decoder_proj,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_size=args.context_size,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@@ -261,7 +261,7 @@ class RandomGrad(torch.nn.Module):
self.min_abs = min_abs
def forward(self, x: Tensor):
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
return x
else:
return RandomGradFunction.apply(x, self.min_abs)
@@ -530,7 +530,7 @@ class ActivationBalancer(torch.nn.Module):
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
return _no_op(x)
count = self.cpu_count
@@ -790,7 +790,7 @@ def with_loss(x, y):
def _no_op(x: Tensor) -> Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x
else:
# a no-op function that will have a node in the autograd graph,
@@ -862,6 +862,7 @@ class MaxEig(torch.nn.Module):
torch.jit.is_scripting()
or self.max_var_per_eig <= 0
or random.random() > self.cur_prob
or torch.jit.is_tracing()
):
return _no_op(x)


@@ -0,0 +1,374 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file tests that models can be exported to ONNX.
"""
import os
from icefall import is_module_available
if not is_module_available("onnxruntime"):
raise ValueError("Please 'pip install onnxruntime' first.")
import onnxruntime as ort
import torch
from scaling_converter import convert_scaled_to_non_scaled
from zipformer import (
Conv2dSubsampling,
RelPositionalEncoding,
Zipformer,
ZipformerEncoder,
ZipformerEncoderLayer,
)
ort.set_default_logger_severity(3)
def test_conv2d_subsampling():
filename = "conv2d_subsampling.onnx"
opset_version = 13
N = 30
T = 50
num_features = 80
d_model = 512
x = torch.rand(N, T, num_features)
encoder_embed = Conv2dSubsampling(num_features, d_model)
encoder_embed.eval()
encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True)
torch.onnx.export(
encoder_embed,
x,
filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"y": {0: "N", 1: "T"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
assert input_nodes[0].name == "x"
assert input_nodes[0].shape == ["N", "T", num_features]
inputs = {input_nodes[0].name: x.numpy()}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)
torch_y = encoder_embed(x)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
os.remove(filename)
def test_rel_pos():
filename = "rel_pos.onnx"
opset_version = 13
N = 30
T = 50
num_features = 80
d_model = 512
x = torch.rand(N, T, num_features)
encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1)
encoder_pos.eval()
encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)
x = x.permute(1, 0, 2)
torch.onnx.export(
encoder_pos,
x,
filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["pos_emb"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"pos_emb": {0: "N", 1: "T"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
assert input_nodes[0].name == "x"
assert input_nodes[0].shape == ["N", "T", num_features]
inputs = {input_nodes[0].name: x.numpy()}
onnx_pos_emb = session.run(["pos_emb"], inputs)
onnx_pos_emb = torch.from_numpy(onnx_pos_emb[0])
torch_pos_emb = encoder_pos(x)
assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), (
(onnx_pos_emb - torch_pos_emb).abs().max()
)
print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum())
os.remove(filename)
def test_zipformer_encoder_layer():
filename = "zipformer_encoder_layer.onnx"
opset_version = 13
N = 30
T = 50
d_model = 384
attention_dim = 192
nhead = 8
feedforward_dim = 1024
dropout = 0.1
cnn_module_kernel = 31
pos_dim = 4
x = torch.rand(N, T, d_model)
encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_pos.eval()
encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)
x = x.permute(1, 0, 2)
pos_emb = encoder_pos(x)
encoder_layer = ZipformerEncoderLayer(
d_model,
attention_dim,
nhead,
feedforward_dim,
dropout,
cnn_module_kernel,
pos_dim,
)
encoder_layer.eval()
encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True)
torch.onnx.export(
encoder_layer,
(x, pos_emb),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "pos_emb"],
output_names=["y"],
dynamic_axes={
"x": {0: "T", 1: "N"},
"pos_emb": {0: "N", 1: "T"},
"y": {0: "T", 1: "N"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
input_nodes[1].name: pos_emb.numpy(),
}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)
torch_y = encoder_layer(x, pos_emb)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
os.remove(filename)
def test_zipformer_encoder():
filename = "zipformer_encoder.onnx"
opset_version = 13
N = 3
T = 15
d_model = 512
attention_dim = 192
nhead = 8
feedforward_dim = 1024
dropout = 0.1
cnn_module_kernel = 31
pos_dim = 4
num_encoder_layers = 12
warmup_batches = 4000.0
warmup_begin = warmup_batches / (num_encoder_layers + 1)
warmup_end = warmup_batches / (num_encoder_layers + 1)
x = torch.rand(N, T, d_model)
encoder_layer = ZipformerEncoderLayer(
d_model,
attention_dim,
nhead,
feedforward_dim,
dropout,
cnn_module_kernel,
pos_dim,
)
encoder = ZipformerEncoder(
encoder_layer, num_encoder_layers, dropout, warmup_begin, warmup_end
)
encoder.eval()
encoder = convert_scaled_to_non_scaled(encoder, inplace=True)
# jit_model = torch.jit.trace(encoder, (pos_emb))
torch_y = encoder(x)
torch.onnx.export(
encoder,
(x),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["y"],
dynamic_axes={
"x": {0: "T", 1: "N"},
"y": {0: "T", 1: "N"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
}
onnx_y = session.run(["y"], inputs)[0]
onnx_y = torch.from_numpy(onnx_y)
torch_y = encoder(x)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
os.remove(filename)
def test_zipformer():
filename = "zipformer.onnx"
opset_version = 11
N = 3
T = 15
num_features = 80
x = torch.rand(N, T, num_features)
x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)
zipformer = Zipformer(num_features=num_features)
zipformer.eval()
zipformer = convert_scaled_to_non_scaled(zipformer, inplace=True)
# jit_model = torch.jit.trace(zipformer, (x, x_lens))
torch.onnx.export(
zipformer,
(x, x_lens),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["y", "y_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"y": {0: "N", 1: "T"},
"y_lens": {0: "N"},
},
)
options = ort.SessionOptions()
options.inter_op_num_threads = 1
options.intra_op_num_threads = 1
session = ort.InferenceSession(
filename,
sess_options=options,
)
input_nodes = session.get_inputs()
inputs = {
input_nodes[0].name: x.numpy(),
input_nodes[1].name: x_lens.numpy(),
}
onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs)
onnx_y = torch.from_numpy(onnx_y)
onnx_y_lens = torch.from_numpy(onnx_y_lens)
torch_y, torch_y_lens = zipformer(x, x_lens)
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), (
(onnx_y_lens - torch_y_lens).abs().max()
)
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
print(onnx_y_lens, torch_y_lens)
os.remove(filename)
@torch.no_grad()
def main():
test_conv2d_subsampling()
test_rel_pos()
test_zipformer_encoder_layer()
test_zipformer_encoder()
test_zipformer()
if __name__ == "__main__":
torch.manual_seed(20221011)
main()


@@ -210,7 +210,7 @@ class Zipformer(EncoderInterface):
(num_frames, batch_size, encoder_dims0)
"""
num_encoders = len(self.encoder_dims)
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
return [1.0] * num_encoders
(num_frames0, batch_size, _encoder_dims0) = x.shape
@@ -293,7 +293,7 @@ class Zipformer(EncoderInterface):
k = self.skip_layers[i]
if isinstance(k, int):
layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
if torch.jit.is_scripting() or torch.jit.is_tracing():
x = skip_module(outputs[k], x)
elif (not self.training) or random.random() > layer_skip_dropout_prob:
x = skip_module(outputs[k], x)
@@ -386,7 +386,7 @@ class ZipformerEncoderLayer(nn.Module):
)
def get_bypass_scale(self):
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
return self.bypass_scale
if random.random() < 0.1:
# ensure we get grads if self.bypass_scale becomes out of range
@@ -407,7 +407,7 @@ class ZipformerEncoderLayer(nn.Module):
# return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
# starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable
# at the beginning, by making the network focus on the feedforward modules.
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
return 0.0
warmup_period = 2000.0
initial_dropout_rate = 0.2
@@ -452,12 +452,12 @@ class ZipformerEncoderLayer(nn.Module):
dynamic_dropout = self.get_dynamic_dropout_rate()
# pooling module
if torch.jit.is_scripting() or torch.jit.is_tracing():
src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)
elif random.random() >= dynamic_dropout:
src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)
if torch.jit.is_scripting() or torch.jit.is_tracing():
src_att, attn_weights = self.self_attn(
src,
pos_emb=pos_emb,
@@ -658,7 +658,7 @@ class ZipformerEncoder(nn.Module):
pos_emb = self.encoder_pos(src)
output = src
if torch.jit.is_scripting() or torch.jit.is_tracing():
layers_to_drop = []
else:
rnd_seed = src.numel() + random.randint(0, 1000)
@@ -667,7 +667,7 @@ class ZipformerEncoder(nn.Module):
output = output * feature_mask
for i, mod in enumerate(self.layers):
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
if i in layers_to_drop:
continue
output = mod(
@@ -864,7 +864,7 @@ class SimpleCombiner(torch.nn.Module):
assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape)
weight1 = self.weight1
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
if (
self.training
and random.random() < 0.25
@@ -1258,6 +1258,16 @@ class RelPositionMultiheadAttention(nn.Module):
# the following .as_strided() expression converts the last axis of pos_weights from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
if torch.jit.is_tracing():
(batch_size, num_heads, time1, n) = pos_weights.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_weights = pos_weights.reshape(-1, n)
pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len)
else:
pos_weights = pos_weights.as_strided(
(bsz, num_heads, seq_len, seq_len),
(
@@ -1272,7 +1282,7 @@ class RelPositionMultiheadAttention(nn.Module):
# caution: they are really scores at this point.
attn_output_weights = torch.matmul(q, k) + pos_weights
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
if training and random.random() < 0.1:
# This is a harder way of limiting the attention scores to not be too large.
# It incurs a penalty if any of them has an absolute value greater than 50.0.
@@ -1383,7 +1393,7 @@ class RelPositionMultiheadAttention(nn.Module):
# now v: (bsz * num_heads, seq_len, head_dim // 2)
attn_output = torch.bmm(attn_weights, v)
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
if random.random() < 0.001 or __name__ == "__main__":
self._print_attn_stats(attn_weights, attn_output)
@@ -1458,6 +1468,9 @@ class PoolingModule(nn.Module):
a Tensor of shape (1, N, C)
"""
if key_padding_mask is not None:
if torch.jit.is_tracing():
pooling_mask = (~key_padding_mask).to(x.dtype)
else:
pooling_mask = key_padding_mask.logical_not().to(x.dtype)  # (N, T)
pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True)
pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
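
For context on the "fix pooling_mask" item in the commit message: the traced branch above builds the mask with the `~` operator instead of `Tensor.logical_not()`, but the two are equivalent on boolean tensors. A quick illustrative check (not part of the diff):

import torch

key_padding_mask = torch.tensor([[False, False, True]])  # (N, T), True marks padding
a = (~key_padding_mask).to(torch.float32)
b = key_padding_mask.logical_not().to(torch.float32)
assert torch.equal(a, b)  # both give 1.0 for valid frames and 0.0 for padding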