mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add Zipformer Onnx Support (#778)
* add export script * add zipformer onnx pretrained script * add onnx zipformer test * fix style * add zipformer onnx to workflow * replace is_in_onnx_export with is_tracing * add github.event.label.name == 'onnx' * add is_tracing to necessary conditions * fix pooling_mask * add onnx_check * add onnx_check to scripts * add is_tracing to scaling.py
This commit is contained in:
parent
80cce141b4
commit
0f26edfde9
@ -30,6 +30,15 @@ ln -s pretrained.pt epoch-99.pt
|
||||
ls -lh *.pt
|
||||
popd
|
||||
|
||||
log "Test exporting to ONNX format"
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--use-averaged-model false \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--onnx 1
|
||||
|
||||
log "Export to torchscript model"
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
@ -41,6 +50,27 @@ log "Export to torchscript model"
|
||||
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
log "Decode with ONNX models"
|
||||
|
||||
./pruned_transducer_stateless7/onnx_check.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner.onnx \
|
||||
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
|
||||
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
|
||||
|
||||
./pruned_transducer_stateless7/onnx_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
log "Decode with models exported by torch.jit.script()"
|
||||
|
||||
./pruned_transducer_stateless7/jit_pretrained.py \
|
||||
|
@ -39,7 +39,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
run_librispeech_2022_11_11_zipformer:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
|
@ -41,7 +41,31 @@ Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
(2) Export to ONNX format
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--onnx 1
|
||||
|
||||
It will generate the following files in the given `exp_dir`.
|
||||
Check `onnx_check.py` for how to use them.
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
|
||||
Please see ./onnx_pretrained.py for usage of the generated files
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa-onnx
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(3) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
@ -172,6 +196,23 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If True, --jit is ignored and it exports the model
|
||||
to onnx format. It will generate the following files:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
|
||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
@ -184,6 +225,204 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T, C)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(1, 101, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([101], dtype=torch.int64)
|
||||
|
||||
# encoder_model = torch.jit.script(encoder_model)
|
||||
# It throws the following error for the above statement
|
||||
#
|
||||
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
||||
# 11 is not supported. Please feel free to request support or
|
||||
# submit a pull request on PyTorch GitHub.
|
||||
#
|
||||
# I cannot find which statement causes the above error.
|
||||
# torch.onnx.export() will use torch.jit.trace() internally, which
|
||||
# works well for the current reworked model
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, x_lens),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens"],
|
||||
output_names=["encoder_out", "encoder_out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = False # Always False, so we can use torch.jit.trace() here
|
||||
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
||||
# in this case
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
(y, need_pad),
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y", "need_pad"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
|
||||
The exported encoder_proj model has one input:
|
||||
|
||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
The exported decoder_proj model has one input:
|
||||
|
||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
|
||||
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
|
||||
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
||||
|
||||
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
||||
|
||||
project_input = False
|
||||
# Note: It uses torch.jit.trace() internally
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out, project_input),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
"project_input",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.encoder_proj,
|
||||
encoder_out,
|
||||
encoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out"],
|
||||
output_names=["projected_encoder_out"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"projected_encoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_proj_filename}")
|
||||
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.decoder_proj,
|
||||
decoder_out,
|
||||
decoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["decoder_out"],
|
||||
output_names=["projected_decoder_out"],
|
||||
dynamic_axes={
|
||||
"decoder_out": {0: "N"},
|
||||
"projected_decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_proj_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
@ -292,7 +531,31 @@ def main():
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
if params.onnx is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
opset_version = 13
|
||||
logging.info("Exporting to onnx format")
|
||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
||||
export_encoder_model_onnx(
|
||||
model.encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||
export_decoder_model_onnx(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||
export_joiner_model_onnx(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
elif params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
|
286
egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py
Executable file
286
egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py
Executable file
@ -0,0 +1,286 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script checks that exported onnx models produce the same output
|
||||
with the given torchscript model for the same input.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("onnxruntime"):
|
||||
raise ValueError("Please 'pip install onnxruntime' first.")
|
||||
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-encoder-proj-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx joiner encoder projection model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-decoder-proj-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx joiner decoder projection model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_encoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_session: ort.InferenceSession,
|
||||
):
|
||||
inputs = encoder_session.get_inputs()
|
||||
outputs = encoder_session.get_outputs()
|
||||
input_names = [n.name for n in inputs]
|
||||
output_names = [n.name for n in outputs]
|
||||
|
||||
assert inputs[0].shape == ["N", "T", 80]
|
||||
assert inputs[1].shape == ["N"]
|
||||
|
||||
for N in [1, 5]:
|
||||
for T in [12, 50]:
|
||||
print("N, T", N, T)
|
||||
x = torch.rand(N, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
|
||||
x_lens[0] = T
|
||||
|
||||
encoder_inputs = {
|
||||
input_names[0]: x.numpy(),
|
||||
input_names[1]: x_lens.numpy(),
|
||||
}
|
||||
|
||||
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
|
||||
|
||||
encoder_out, encoder_out_lens = encoder_session.run(
|
||||
output_names,
|
||||
encoder_inputs,
|
||||
)
|
||||
|
||||
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
|
||||
|
||||
encoder_out = torch.from_numpy(encoder_out)
|
||||
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
|
||||
(encoder_out - torch_encoder_out).abs().max(),
|
||||
encoder_out.shape,
|
||||
torch_encoder_out.shape,
|
||||
)
|
||||
|
||||
|
||||
def test_decoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
decoder_session: ort.InferenceSession,
|
||||
):
|
||||
inputs = decoder_session.get_inputs()
|
||||
outputs = decoder_session.get_outputs()
|
||||
input_names = [n.name for n in inputs]
|
||||
output_names = [n.name for n in outputs]
|
||||
|
||||
assert inputs[0].shape == ["N", 2]
|
||||
for N in [1, 5, 10]:
|
||||
y = torch.randint(low=1, high=500, size=(10, 2))
|
||||
|
||||
decoder_inputs = {input_names[0]: y.numpy()}
|
||||
decoder_out = decoder_session.run(
|
||||
output_names,
|
||||
decoder_inputs,
|
||||
)[0]
|
||||
decoder_out = torch.from_numpy(decoder_out)
|
||||
|
||||
torch_decoder_out = model.decoder(y, need_pad=False)
|
||||
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
|
||||
(decoder_out - torch_decoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_joiner(
|
||||
model: torch.jit.ScriptModule,
|
||||
joiner_session: ort.InferenceSession,
|
||||
joiner_encoder_proj_session: ort.InferenceSession,
|
||||
joiner_decoder_proj_session: ort.InferenceSession,
|
||||
):
|
||||
joiner_inputs = joiner_session.get_inputs()
|
||||
joiner_outputs = joiner_session.get_outputs()
|
||||
joiner_input_names = [n.name for n in joiner_inputs]
|
||||
joiner_output_names = [n.name for n in joiner_outputs]
|
||||
|
||||
assert joiner_inputs[0].shape == ["N", 1, 1, 512]
|
||||
assert joiner_inputs[1].shape == ["N", 1, 1, 512]
|
||||
|
||||
joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs()
|
||||
encoder_proj_input_name = joiner_encoder_proj_inputs[0].name
|
||||
|
||||
assert joiner_encoder_proj_inputs[0].shape == ["N", 384]
|
||||
|
||||
joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs()
|
||||
encoder_proj_output_name = joiner_encoder_proj_outputs[0].name
|
||||
|
||||
joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs()
|
||||
decoder_proj_input_name = joiner_decoder_proj_inputs[0].name
|
||||
|
||||
assert joiner_decoder_proj_inputs[0].shape == ["N", 512]
|
||||
|
||||
joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs()
|
||||
decoder_proj_output_name = joiner_decoder_proj_outputs[0].name
|
||||
|
||||
for N in [1, 5, 10]:
|
||||
encoder_out = torch.rand(N, 384)
|
||||
decoder_out = torch.rand(N, 512)
|
||||
|
||||
projected_encoder_out = torch.rand(N, 1, 1, 512)
|
||||
projected_decoder_out = torch.rand(N, 1, 1, 512)
|
||||
|
||||
joiner_inputs = {
|
||||
joiner_input_names[0]: projected_encoder_out.numpy(),
|
||||
joiner_input_names[1]: projected_decoder_out.numpy(),
|
||||
}
|
||||
joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0]
|
||||
joiner_out = torch.from_numpy(joiner_out)
|
||||
|
||||
torch_joiner_out = model.joiner(
|
||||
projected_encoder_out,
|
||||
projected_decoder_out,
|
||||
project_input=False,
|
||||
)
|
||||
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
|
||||
(joiner_out - torch_joiner_out).abs().max()
|
||||
)
|
||||
|
||||
# Now test encoder_proj
|
||||
joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
|
||||
joiner_encoder_proj_out = joiner_encoder_proj_session.run(
|
||||
[encoder_proj_output_name], joiner_encoder_proj_inputs
|
||||
)[0]
|
||||
joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out)
|
||||
|
||||
torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
|
||||
assert torch.allclose(
|
||||
joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
|
||||
), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
|
||||
|
||||
# Now test decoder_proj
|
||||
joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
|
||||
joiner_decoder_proj_out = joiner_decoder_proj_session.run(
|
||||
[decoder_proj_output_name], joiner_decoder_proj_inputs
|
||||
)[0]
|
||||
joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out)
|
||||
|
||||
torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
|
||||
assert torch.allclose(
|
||||
joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
|
||||
), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = torch.jit.load(args.jit_filename)
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
logging.info("Test encoder")
|
||||
encoder_session = ort.InferenceSession(
|
||||
args.onnx_encoder_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_encoder(model, encoder_session)
|
||||
|
||||
logging.info("Test decoder")
|
||||
decoder_session = ort.InferenceSession(
|
||||
args.onnx_decoder_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_decoder(model, decoder_session)
|
||||
|
||||
logging.info("Test joiner")
|
||||
joiner_session = ort.InferenceSession(
|
||||
args.onnx_joiner_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
joiner_encoder_proj_session = ort.InferenceSession(
|
||||
args.onnx_joiner_encoder_proj_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
joiner_decoder_proj_session = ort.InferenceSession(
|
||||
args.onnx_joiner_decoder_proj_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_joiner(
|
||||
model,
|
||||
joiner_session,
|
||||
joiner_encoder_proj_session,
|
||||
joiner_decoder_proj_session,
|
||||
)
|
||||
logging.info("Finished checking ONNX models")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220727)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
388
egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py
Executable file
388
egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py
Executable file
@ -0,0 +1,388 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--onnx 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7/onnx_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless7/exp/encoder.onnx \
|
||||
--decoder-model-filename ./pruned_transducer_stateless7/exp/decoder.onnx \
|
||||
--joiner-model-filename ./pruned_transducer_stateless7/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_decoder_proj.onnx \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-encoder-proj-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner encoder_proj onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-decoder-proj-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner decoder_proj onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Context size of the decoder model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: ort.InferenceSession,
|
||||
joiner: ort.InferenceSession,
|
||||
joiner_encoder_proj: ort.InferenceSession,
|
||||
joiner_decoder_proj: ort.InferenceSession,
|
||||
encoder_out: np.ndarray,
|
||||
encoder_out_lens: np.ndarray,
|
||||
context_size: int,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
decoder:
|
||||
The decoder model.
|
||||
joiner:
|
||||
The joiner model.
|
||||
joiner_encoder_proj:
|
||||
The joiner encoder projection model.
|
||||
joiner_decoder_proj:
|
||||
The joiner decoder projection model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
encoder_out = torch.from_numpy(encoder_out)
|
||||
encoder_out_lens = torch.from_numpy(encoder_out_lens)
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
projected_encoder_out = joiner_encoder_proj.run(
|
||||
[joiner_encoder_proj.get_outputs()[0].name],
|
||||
{joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
|
||||
)[0]
|
||||
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input_nodes = decoder.get_inputs()
|
||||
decoder_output_nodes = decoder.get_outputs()
|
||||
|
||||
joiner_input_nodes = joiner.get_inputs()
|
||||
joiner_output_nodes = joiner.get_outputs()
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = decoder.run(
|
||||
[decoder_output_nodes[0].name],
|
||||
{
|
||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||
},
|
||||
)[0].squeeze(1)
|
||||
projected_decoder_out = joiner_decoder_proj.run(
|
||||
[joiner_decoder_proj.get_outputs()[0].name],
|
||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
||||
)[0]
|
||||
|
||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = projected_encoder_out[start:end]
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
projected_decoder_out = projected_decoder_out[:batch_size]
|
||||
|
||||
logits = joiner.run(
|
||||
[joiner_output_nodes[0].name],
|
||||
{
|
||||
joiner_input_nodes[0].name: np.expand_dims(
|
||||
np.expand_dims(current_encoder_out, axis=1), axis=1
|
||||
),
|
||||
joiner_input_nodes[1]
|
||||
.name: projected_decoder_out.unsqueeze(1)
|
||||
.unsqueeze(1)
|
||||
.numpy(),
|
||||
},
|
||||
)[0]
|
||||
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = decoder.run(
|
||||
[decoder_output_nodes[0].name],
|
||||
{
|
||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||
},
|
||||
)[0].squeeze(1)
|
||||
projected_decoder_out = joiner_decoder_proj.run(
|
||||
[joiner_decoder_proj.get_outputs()[0].name],
|
||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
||||
)[0]
|
||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
encoder = ort.InferenceSession(
|
||||
args.encoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
decoder = ort.InferenceSession(
|
||||
args.decoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner = ort.InferenceSession(
|
||||
args.joiner_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner_encoder_proj = ort.InferenceSession(
|
||||
args.joiner_encoder_proj_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner_decoder_proj = ort.InferenceSession(
|
||||
args.joiner_decoder_proj_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
|
||||
encoder_input_nodes = encoder.get_inputs()
|
||||
encoder_out_nodes = encoder.get_outputs()
|
||||
encoder_out, encoder_out_lens = encoder.run(
|
||||
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
|
||||
{
|
||||
encoder_input_nodes[0].name: features.numpy(),
|
||||
encoder_input_nodes[1].name: feature_lengths.numpy(),
|
||||
},
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
joiner_encoder_proj=joiner_encoder_proj,
|
||||
joiner_decoder_proj=joiner_decoder_proj,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context_size=args.context_size,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = sp.decode(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -261,7 +261,7 @@ class RandomGrad(torch.nn.Module):
|
||||
self.min_abs = min_abs
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
return RandomGradFunction.apply(x, self.min_abs)
|
||||
@ -530,7 +530,7 @@ class ActivationBalancer(torch.nn.Module):
|
||||
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or not x.requires_grad:
|
||||
if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
|
||||
return _no_op(x)
|
||||
|
||||
count = self.cpu_count
|
||||
@ -790,7 +790,7 @@ def with_loss(x, y):
|
||||
|
||||
|
||||
def _no_op(x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
# a no-op function that will have a node in the autograd graph,
|
||||
@ -862,6 +862,7 @@ class MaxEig(torch.nn.Module):
|
||||
torch.jit.is_scripting()
|
||||
or self.max_var_per_eig <= 0
|
||||
or random.random() > self.cur_prob
|
||||
or torch.jit.is_tracing()
|
||||
):
|
||||
return _no_op(x)
|
||||
|
||||
|
374
egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py
Normal file
374
egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py
Normal file
@ -0,0 +1,374 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file is to test that models can be exported to onnx.
|
||||
"""
|
||||
import os
|
||||
|
||||
from icefall import is_module_available
|
||||
|
||||
if not is_module_available("onnxruntime"):
|
||||
raise ValueError("Please 'pip install onnxruntime' first.")
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from zipformer import (
|
||||
Conv2dSubsampling,
|
||||
RelPositionalEncoding,
|
||||
Zipformer,
|
||||
ZipformerEncoder,
|
||||
ZipformerEncoderLayer,
|
||||
)
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
|
||||
|
||||
def test_conv2d_subsampling():
|
||||
filename = "conv2d_subsampling.onnx"
|
||||
opset_version = 13
|
||||
N = 30
|
||||
T = 50
|
||||
num_features = 80
|
||||
d_model = 512
|
||||
x = torch.rand(N, T, num_features)
|
||||
|
||||
encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||
encoder_embed.eval()
|
||||
encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_embed,
|
||||
x,
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x"],
|
||||
output_names=["y"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"y": {0: "N", 1: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
assert input_nodes[0].name == "x"
|
||||
assert input_nodes[0].shape == ["N", "T", num_features]
|
||||
|
||||
inputs = {input_nodes[0].name: x.numpy()}
|
||||
|
||||
onnx_y = session.run(["y"], inputs)[0]
|
||||
|
||||
onnx_y = torch.from_numpy(onnx_y)
|
||||
torch_y = encoder_embed(x)
|
||||
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
|
||||
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def test_rel_pos():
|
||||
filename = "rel_pos.onnx"
|
||||
|
||||
opset_version = 13
|
||||
N = 30
|
||||
T = 50
|
||||
num_features = 80
|
||||
d_model = 512
|
||||
x = torch.rand(N, T, num_features)
|
||||
|
||||
encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1)
|
||||
encoder_pos.eval()
|
||||
encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)
|
||||
|
||||
x = x.permute(1, 0, 2)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_pos,
|
||||
x,
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x"],
|
||||
output_names=["pos_emb"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"pos_emb": {0: "N", 1: "T"},
|
||||
},
|
||||
)
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
assert input_nodes[0].name == "x"
|
||||
assert input_nodes[0].shape == ["N", "T", num_features]
|
||||
|
||||
inputs = {input_nodes[0].name: x.numpy()}
|
||||
onnx_pos_emb = session.run(["pos_emb"], inputs)
|
||||
onnx_pos_emb = torch.from_numpy(onnx_pos_emb[0])
|
||||
|
||||
torch_pos_emb = encoder_pos(x)
|
||||
assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), (
|
||||
(onnx_pos_emb - torch_pos_emb).abs().max()
|
||||
)
|
||||
print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum())
|
||||
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def test_zipformer_encoder_layer():
|
||||
filename = "zipformer_encoder_layer.onnx"
|
||||
opset_version = 13
|
||||
N = 30
|
||||
T = 50
|
||||
|
||||
d_model = 384
|
||||
attention_dim = 192
|
||||
nhead = 8
|
||||
feedforward_dim = 1024
|
||||
dropout = 0.1
|
||||
cnn_module_kernel = 31
|
||||
pos_dim = 4
|
||||
|
||||
x = torch.rand(N, T, d_model)
|
||||
|
||||
encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
encoder_pos.eval()
|
||||
encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True)
|
||||
|
||||
x = x.permute(1, 0, 2)
|
||||
pos_emb = encoder_pos(x)
|
||||
|
||||
encoder_layer = ZipformerEncoderLayer(
|
||||
d_model,
|
||||
attention_dim,
|
||||
nhead,
|
||||
feedforward_dim,
|
||||
dropout,
|
||||
cnn_module_kernel,
|
||||
pos_dim,
|
||||
)
|
||||
encoder_layer.eval()
|
||||
encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_layer,
|
||||
(x, pos_emb),
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "pos_emb"],
|
||||
output_names=["y"],
|
||||
dynamic_axes={
|
||||
"x": {0: "T", 1: "N"},
|
||||
"pos_emb": {0: "N", 1: "T"},
|
||||
"y": {0: "T", 1: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
inputs = {
|
||||
input_nodes[0].name: x.numpy(),
|
||||
input_nodes[1].name: pos_emb.numpy(),
|
||||
}
|
||||
onnx_y = session.run(["y"], inputs)[0]
|
||||
onnx_y = torch.from_numpy(onnx_y)
|
||||
|
||||
torch_y = encoder_layer(x, pos_emb)
|
||||
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
|
||||
|
||||
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
|
||||
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def test_zipformer_encoder():
|
||||
filename = "zipformer_encoder.onnx"
|
||||
|
||||
opset_version = 13
|
||||
N = 3
|
||||
T = 15
|
||||
|
||||
d_model = 512
|
||||
attention_dim = 192
|
||||
nhead = 8
|
||||
feedforward_dim = 1024
|
||||
dropout = 0.1
|
||||
cnn_module_kernel = 31
|
||||
pos_dim = 4
|
||||
num_encoder_layers = 12
|
||||
|
||||
warmup_batches = 4000.0
|
||||
warmup_begin = warmup_batches / (num_encoder_layers + 1)
|
||||
warmup_end = warmup_batches / (num_encoder_layers + 1)
|
||||
|
||||
x = torch.rand(N, T, d_model)
|
||||
|
||||
encoder_layer = ZipformerEncoderLayer(
|
||||
d_model,
|
||||
attention_dim,
|
||||
nhead,
|
||||
feedforward_dim,
|
||||
dropout,
|
||||
cnn_module_kernel,
|
||||
pos_dim,
|
||||
)
|
||||
encoder = ZipformerEncoder(
|
||||
encoder_layer, num_encoder_layers, dropout, warmup_begin, warmup_end
|
||||
)
|
||||
encoder.eval()
|
||||
encoder = convert_scaled_to_non_scaled(encoder, inplace=True)
|
||||
|
||||
# jit_model = torch.jit.trace(encoder, (pos_emb))
|
||||
|
||||
torch_y = encoder(x)
|
||||
|
||||
torch.onnx.export(
|
||||
encoder,
|
||||
(x),
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x"],
|
||||
output_names=["y"],
|
||||
dynamic_axes={
|
||||
"x": {0: "T", 1: "N"},
|
||||
"y": {0: "T", 1: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
inputs = {
|
||||
input_nodes[0].name: x.numpy(),
|
||||
}
|
||||
onnx_y = session.run(["y"], inputs)[0]
|
||||
onnx_y = torch.from_numpy(onnx_y)
|
||||
|
||||
torch_y = encoder(x)
|
||||
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
|
||||
|
||||
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
|
||||
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def test_zipformer():
|
||||
filename = "zipformer.onnx"
|
||||
opset_version = 11
|
||||
N = 3
|
||||
T = 15
|
||||
num_features = 80
|
||||
x = torch.rand(N, T, num_features)
|
||||
x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)
|
||||
|
||||
zipformer = Zipformer(num_features=num_features)
|
||||
zipformer.eval()
|
||||
zipformer = convert_scaled_to_non_scaled(zipformer, inplace=True)
|
||||
|
||||
# jit_model = torch.jit.trace(zipformer, (x, x_lens))
|
||||
torch.onnx.export(
|
||||
zipformer,
|
||||
(x, x_lens),
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens"],
|
||||
output_names=["y", "y_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"y": {0: "N", 1: "T"},
|
||||
"y_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
session = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=options,
|
||||
)
|
||||
|
||||
input_nodes = session.get_inputs()
|
||||
inputs = {
|
||||
input_nodes[0].name: x.numpy(),
|
||||
input_nodes[1].name: x_lens.numpy(),
|
||||
}
|
||||
onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs)
|
||||
onnx_y = torch.from_numpy(onnx_y)
|
||||
onnx_y_lens = torch.from_numpy(onnx_y_lens)
|
||||
|
||||
torch_y, torch_y_lens = zipformer(x, x_lens)
|
||||
assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max()
|
||||
|
||||
assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), (
|
||||
(onnx_y_lens - torch_y_lens).abs().max()
|
||||
)
|
||||
print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape)
|
||||
print(onnx_y_lens, torch_y_lens)
|
||||
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
test_conv2d_subsampling()
|
||||
test_rel_pos()
|
||||
test_zipformer_encoder_layer()
|
||||
test_zipformer_encoder()
|
||||
test_zipformer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20221011)
|
||||
main()
|
@ -210,7 +210,7 @@ class Zipformer(EncoderInterface):
|
||||
(num_frames, batch_size, encoder_dims0)
|
||||
"""
|
||||
num_encoders = len(self.encoder_dims)
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
|
||||
return [1.0] * num_encoders
|
||||
|
||||
(num_frames0, batch_size, _encoder_dims0) = x.shape
|
||||
@ -293,7 +293,7 @@ class Zipformer(EncoderInterface):
|
||||
k = self.skip_layers[i]
|
||||
if isinstance(k, int):
|
||||
layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
x = skip_module(outputs[k], x)
|
||||
elif (not self.training) or random.random() > layer_skip_dropout_prob:
|
||||
x = skip_module(outputs[k], x)
|
||||
@ -386,7 +386,7 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
def get_bypass_scale(self):
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
|
||||
return self.bypass_scale
|
||||
if random.random() < 0.1:
|
||||
# ensure we get grads if self.bypass_scale becomes out of range
|
||||
@ -407,7 +407,7 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
# return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
|
||||
# starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable
|
||||
# at the beginning, by making the network focus on the feedforward modules.
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
|
||||
return 0.0
|
||||
warmup_period = 2000.0
|
||||
initial_dropout_rate = 0.2
|
||||
@ -452,12 +452,12 @@ class ZipformerEncoderLayer(nn.Module):
|
||||
dynamic_dropout = self.get_dynamic_dropout_rate()
|
||||
|
||||
# pooling module
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)
|
||||
elif random.random() >= dynamic_dropout:
|
||||
src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
src_att, attn_weights = self.self_attn(
|
||||
src,
|
||||
pos_emb=pos_emb,
|
||||
@ -658,7 +658,7 @@ class ZipformerEncoder(nn.Module):
|
||||
pos_emb = self.encoder_pos(src)
|
||||
output = src
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
layers_to_drop = []
|
||||
else:
|
||||
rnd_seed = src.numel() + random.randint(0, 1000)
|
||||
@ -667,7 +667,7 @@ class ZipformerEncoder(nn.Module):
|
||||
output = output * feature_mask
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
if not torch.jit.is_scripting():
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
if i in layers_to_drop:
|
||||
continue
|
||||
output = mod(
|
||||
@ -864,7 +864,7 @@ class SimpleCombiner(torch.nn.Module):
|
||||
assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape)
|
||||
|
||||
weight1 = self.weight1
|
||||
if not torch.jit.is_scripting():
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
if (
|
||||
self.training
|
||||
and random.random() < 0.25
|
||||
@ -1258,21 +1258,31 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
# the following .as_strided() expression converts the last axis of pos_weights from relative
|
||||
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||
# not, but let this code define which way round it is supposed to be.
|
||||
pos_weights = pos_weights.as_strided(
|
||||
(bsz, num_heads, seq_len, seq_len),
|
||||
(
|
||||
pos_weights.stride(0),
|
||||
pos_weights.stride(1),
|
||||
pos_weights.stride(2) - pos_weights.stride(3),
|
||||
pos_weights.stride(3),
|
||||
),
|
||||
storage_offset=pos_weights.stride(3) * (seq_len - 1),
|
||||
)
|
||||
if torch.jit.is_tracing():
|
||||
(batch_size, num_heads, time1, n) = pos_weights.shape
|
||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||
cols = torch.arange(seq_len)
|
||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
pos_weights = pos_weights.reshape(-1, n)
|
||||
pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
|
||||
pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len)
|
||||
else:
|
||||
pos_weights = pos_weights.as_strided(
|
||||
(bsz, num_heads, seq_len, seq_len),
|
||||
(
|
||||
pos_weights.stride(0),
|
||||
pos_weights.stride(1),
|
||||
pos_weights.stride(2) - pos_weights.stride(3),
|
||||
pos_weights.stride(3),
|
||||
),
|
||||
storage_offset=pos_weights.stride(3) * (seq_len - 1),
|
||||
)
|
||||
|
||||
# caution: they are really scores at this point.
|
||||
attn_output_weights = torch.matmul(q, k) + pos_weights
|
||||
|
||||
if not torch.jit.is_scripting():
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
if training and random.random() < 0.1:
|
||||
# This is a harder way of limiting the attention scores to not be too large.
|
||||
# It incurs a penalty if any of them has an absolute value greater than 50.0.
|
||||
@ -1383,7 +1393,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
# now v: (bsz * num_heads, seq_len, head_dim // 2)
|
||||
attn_output = torch.bmm(attn_weights, v)
|
||||
|
||||
if not torch.jit.is_scripting():
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
if random.random() < 0.001 or __name__ == "__main__":
|
||||
self._print_attn_stats(attn_weights, attn_output)
|
||||
|
||||
@ -1458,7 +1468,10 @@ class PoolingModule(nn.Module):
|
||||
a Tensor of shape (1, N, C)
|
||||
"""
|
||||
if key_padding_mask is not None:
|
||||
pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T)
|
||||
if torch.jit.is_tracing():
|
||||
pooling_mask = (~key_padding_mask).to(x.dtype)
|
||||
else:
|
||||
pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T)
|
||||
pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True)
|
||||
pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1)
|
||||
# now pooling_mask: (T, N, 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user