Add export to ONNX for Zipformer+CTC using blank skip (#861)

* Add export to ONNX for Zipformer+CTC using blank skip

---------

Co-authored-by: yifanyang <yifanyeung@yifanyangs-MacBook-Pro.local>
Yifan Yang 2023-01-31 15:57:03 +08:00 committed by GitHub
parent e9019511eb
commit d8234e199c
4 changed files with 1188 additions and 20 deletions
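
This commit wires ONNX export into the pruned_transducer_stateless7_ctc_bs recipe, where "blank skip" means dropping encoder frames whose CTC blank posterior is high before they reach the lconv and transducer decoding stages. A minimal sketch of that masking step, assuming a log-softmax CTC output and the log(0.9) blank threshold used by the FrameReducer below:

import math
import torch

# toy shapes: 2 utterances, 6 frames, vocab size 5, blank id 0
ctc_output = torch.randn(2, 6, 5).log_softmax(dim=-1)  # log-probs from the CTC head
x_lens = torch.tensor([6, 4])

padding_mask = torch.arange(6).expand(2, 6) >= x_lens.unsqueeze(1)
# keep a frame only if its blank log-prob is below log(0.9) and it is not padding
non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) & (~padding_mask)
print(non_blank_mask.sum(dim=1))  # frames surviving blank skip, per utterance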


@@ -72,14 +72,14 @@ Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
# You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp
"""
import argparse


@@ -0,0 +1,665 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script exports a trained model to ONNX format. If requested, it
# first averages several saved checkpoints into a single model.
"""
Usage:
(1) Export to ONNX format
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 13
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- lconv.onnx
- frame_reducer.onnx
- ctc_output.onnx
Please see ./onnx_pretrained.py for usage of the generated files
Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
# You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_ctc_bs/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--onnx",
type=str2bool,
default=True,
help="""If True, --jit is ignored and it exports the model
to onnx format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- lconv.onnx
- frame_reducer.onnx
- ctc_output.onnx
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def export_encoder_model_onnx(
encoder_model: nn.Module,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T, C)
- encoder_out_lens, a tensor of shape (N,)
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(15, 2000, 80, dtype=torch.float32)
x_lens = torch.tensor([2000] * 15, dtype=torch.int64)
# encoder_model = torch.jit.script(encoder_model)
# It throws the following error for the above statement
#
# RuntimeError: Exporting the operator __is_ to ONNX opset version
# 11 is not supported. Please feel free to request support or
# submit a pull request on PyTorch GitHub.
#
# I cannot find which statement causes the above error.
# torch.onnx.export() will use torch.jit.trace() internally, which
# works well for the current reworked model
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(15, decoder_model.context_size, dtype=torch.int64)
need_pad = False # Always False, so we can use torch.jit.trace() here
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
# in this case
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_proj_filename}")
def export_lconv_onnx(
lconv: nn.Module,
lconv_filename: str,
opset_version: int = 11,
) -> None:
"""Export the lconv to ONNX format.
The exported lconv has two inputs:
- lconv_input: a tensor of shape (N, T, C)
- src_key_padding_mask: a tensor of shape (N, T)
and has one output:
- lconv_out: a tensor of shape (N, T, C)
Args:
lconv:
The lconv to be exported.
lconv_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32)
src_key_padding_mask = torch.zeros(15, 498, dtype=torch.bool)
torch.onnx.export(
lconv,
(lconv_input, src_key_padding_mask),
lconv_filename,
verbose=False,
opset_version=opset_version,
input_names=["lconv_input", "src_key_padding_mask"],
output_names=["lconv_out"],
dynamic_axes={
"lconv_input": {0: "N", 1: "T"},
"src_key_padding_mask": {0: "N", 1: "T"},
"lconv_out": {0: "N", 1: "T"},
},
)
logging.info(f"Saved to {lconv_filename}")
def export_frame_reducer_onnx(
frame_reducer: nn.Module,
frame_reducer_filename: str,
opset_version: int = 11,
) -> None:
"""Export the frame_reducer to ONNX format.
The exported frame_reducer has three inputs:
- x: a tensor of shape (N, T, C)
- x_lens: a tensor of shape (N,)
- ctc_output: a tensor of shape (N, T, vocab_size)
and has two outputs:
- out: a tensor of shape (N, T', C)
- out_lens: a tensor of shape (N,)
Args:
frame_reducer:
The frame_reducer to be exported.
frame_reducer_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.randn(15, 498, 500, dtype=torch.float32)
torch.onnx.export(
frame_reducer,
(x, x_lens, ctc_output),
frame_reducer_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens", "ctc_output"],
output_names=["out", "out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"ctc_output": {0: "N", 1: "T"},
"out": {0: "N", 1: "T"},
"out_lens": {0: "N"},
},
)
logging.info(f"Saved to {frame_reducer_filename}")
def export_ctc_output_onnx(
ctc_output: nn.Module,
ctc_output_filename: str,
opset_version: int = 11,
) -> None:
"""Export the frame_reducer to ONNX format.
The exported frame_reducer has one inputs:
- encoder_out: a tensor of shape (N, T, C)
and has one output:
- ctc_output: a tensor of shape (N, T, vocab_size)
Args:
ctc_output:
The ctc_output to be exported.
ctc_output_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
encoder_out = torch.zeros(15, 498, 384, dtype=torch.float32)
torch.onnx.export(
ctc_output,
(encoder_out),
ctc_output_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["ctc_output"],
dynamic_axes={
"encoder_out": {0: "N", 1: "T"},
"ctc_output": {0: "N", 1: "T"},
},
)
logging.info(f"Saved to {ctc_output_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
opset_version = 13
logging.info("Exporting to onnx format")
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
lconv_filename = params.exp_dir / "lconv.onnx"
export_lconv_onnx(
model.lconv,
lconv_filename,
opset_version=opset_version,
)
frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
export_frame_reducer_onnx(
model.frame_reducer,
frame_reducer_filename,
opset_version=opset_version,
)
ctc_output_filename = params.exp_dir / "ctc_output.onnx"
export_ctc_output_onnx(
model.ctc_output,
ctc_output_filename,
opset_version=opset_version,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
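
Once the export script has run, the generated files can be sanity-checked directly with onnxruntime. A minimal sketch for the encoder, assuming it was written to pruned_transducer_stateless7_ctc_bs/exp/encoder.onnx and using the input/output names declared in export_encoder_model_onnx above:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("pruned_transducer_stateless7_ctc_bs/exp/encoder.onnx")

x = np.zeros((1, 200, 80), dtype=np.float32)  # (N, T, C) fbank features
x_lens = np.array([200], dtype=np.int64)      # (N,)

encoder_out, encoder_out_lens = session.run(
    ["encoder_out", "encoder_out_lens"],
    {"x": x, "x_lens": x_lens},
)
print(encoder_out.shape, encoder_out_lens)  # e.g. (1, T', 384) after subsampling

The same pattern applies to the other exported files; onnx_pretrained.py below chains all of them for decoding.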


@@ -22,7 +22,8 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from icefall.utils import make_pad_mask
@@ -43,7 +44,6 @@ class FrameReducer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
ctc_output: torch.Tensor,
blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -54,26 +54,68 @@ class FrameReducer(nn.Module):
`x` before padding.
ctc_output:
The CTC output with shape [N, T, vocab_size].
blank_id:
The ID of the blank symbol.
Returns:
x_fr:
out:
The frame reduced encoder output with shape [N, T', C].
x_lens_fr:
out_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x_fr` before padding.
`out` before padding.
"""
N, T, C = x.size()
padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask)
frames_list: List[torch.Tensor] = []
lens_list: List[int] = []
for i in range(x.shape[0]):
frames = x[i][non_blank_mask[i]]
frames_list.append(frames)
lens_list.append(frames.shape[0])
x_fr = pad_sequence(frames_list, batch_first=True)
x_lens_fr = torch.tensor(lens_list).to(device=x.device)
out_lens = non_blank_mask.sum(dim=1)
max_len = out_lens.max()
pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
max_pad_len = pad_lens_list.max()
return x_fr, x_lens_fr
out = F.pad(x, (0, 0, 0, max_pad_len))
valid_pad_mask = ~make_pad_mask(pad_lens_list)
total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
out = out[total_valid_mask].reshape(N, -1, C)
return out.to(device=x.device), out_lens.to(device=x.device)
if __name__ == "__main__":
import time
from torch.nn.utils.rnn import pad_sequence
test_times = 10000
frame_reducer = FrameReducer()
# non zero case
x = torch.ones(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.randn(15, 498, 500, dtype=torch.float32).log_softmax(dim=-1)
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
avg_time = 0
for i in range(test_times):
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)
print(x_lens_fr)
print(avg_time / test_times)
# all zero case
x = torch.zeros(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
avg_time = 0
for i in range(test_times):
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)
print(x_lens_fr)
print(avg_time / test_times)
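
The forward pass above replaces the old per-utterance Python loop (pad_sequence over the selected frames) with a vectorized pad-and-mask formulation so the module can be traced by torch.onnx.export. A small sketch checking that the two formulations agree, assuming this FrameReducer can be imported from a module named frame_reducer (hypothetical import path):

import math
import torch
from torch.nn.utils.rnn import pad_sequence

from icefall.utils import make_pad_mask
from frame_reducer import FrameReducer  # assumed module name for the file shown above

def frame_reduce_loop(x, x_lens, ctc_output):
    # reference: the loop-based formulation this commit replaces
    padding_mask = make_pad_mask(x_lens)
    non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) & (~padding_mask)
    frames_list, lens_list = [], []
    for i in range(x.shape[0]):
        frames = x[i][non_blank_mask[i]]
        frames_list.append(frames)
        lens_list.append(frames.shape[0])
    return pad_sequence(frames_list, batch_first=True), torch.tensor(lens_list)

x = torch.randn(3, 100, 384)
x_lens = torch.tensor([100, 80, 60])
ctc_output = torch.randn(3, 100, 500).log_softmax(dim=-1)

ref_out, ref_lens = frame_reduce_loop(x, x_lens, ctc_output)
out, out_lens = FrameReducer()(x, x_lens, ctc_output)  # vectorized version above
assert torch.equal(out_lens.cpu(), ref_lens)
assert torch.allclose(out, ref_out)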


@@ -0,0 +1,461 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 13
Usage of this script:
./pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/encoder.onnx \
--decoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/decoder.onnx \
--joiner-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner.onnx \
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_decoder_proj.onnx \
--lconv-filename ./pruned_transducer_stateless7_ctc_bs/exp/lconv.onnx \
--frame-reducer-filename ./pruned_transducer_stateless7_ctc_bs/exp/frame_reducer.onnx \
--ctc-output-filename ./pruned_transducer_stateless7_ctc_bs/exp/ctc_output.onnx \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from icefall.utils import make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--joiner-encoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner encoder_proj onnx model. ",
)
parser.add_argument(
"--joiner-decoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner decoder_proj onnx model. ",
)
parser.add_argument(
"--lconv-filename",
type=str,
required=True,
help="Path to the lconv onnx model. ",
)
parser.add_argument(
"--frame-reducer-filename",
type=str,
required=True,
help="Path to the frame reducer onnx model. ",
)
parser.add_argument(
"--ctc-output-filename",
type=str,
required=True,
help="Path to the ctc_output onnx model. ",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="Context size of the decoder model",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: ort.InferenceSession,
joiner: ort.InferenceSession,
joiner_encoder_proj: ort.InferenceSession,
joiner_decoder_proj: ort.InferenceSession,
encoder_out: np.ndarray,
encoder_out_lens: np.ndarray,
context_size: int,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
decoder:
The decoder model.
joiner:
The joiner model.
joiner_encoder_proj:
The joiner encoder projection model.
joiner_decoder_proj:
The joiner decoder projection model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
context_size:
The context size of the decoder model.
Returns:
Return the decoded results for each utterance.
"""
encoder_out = torch.from_numpy(encoder_out)
encoder_out_lens = torch.from_numpy(encoder_out_lens)
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
projected_encoder_out = joiner_encoder_proj.run(
[joiner_encoder_proj.get_outputs()[0].name],
{joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
)[0]
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input_nodes = decoder.get_inputs()
decoder_output_nodes = decoder.get_outputs()
joiner_input_nodes = joiner.get_inputs()
joiner_output_nodes = joiner.get_outputs()
decoder_input = torch.tensor(
hyps,
dtype=torch.int64,
) # (N, context_size)
decoder_out = decoder.run(
[decoder_output_nodes[0].name],
{
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = projected_encoder_out[start:end]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
projected_decoder_out = projected_decoder_out[:batch_size]
logits = joiner.run(
[joiner_output_nodes[0].name],
{
joiner_input_nodes[0].name: np.expand_dims(
np.expand_dims(current_encoder_out, axis=1), axis=1
),
joiner_input_nodes[1]
.name: projected_decoder_out.unsqueeze(1)
.unsqueeze(1)
.numpy(),
},
)[0]
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
# logits' shape is (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
dtype=torch.int64,
)
decoder_out = decoder.run(
[decoder_output_nodes[0].name],
{
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
encoder = ort.InferenceSession(
args.encoder_model_filename,
sess_options=session_opts,
)
decoder = ort.InferenceSession(
args.decoder_model_filename,
sess_options=session_opts,
)
joiner = ort.InferenceSession(
args.joiner_model_filename,
sess_options=session_opts,
)
joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=session_opts,
)
joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=session_opts,
)
lconv = ort.InferenceSession(
args.lconv_filename,
sess_options=session_opts,
)
frame_reducer = ort.InferenceSession(
args.frame_reducer_filename,
sess_options=session_opts,
)
ctc_output = ort.InferenceSession(
args.ctc_output_filename,
sess_options=session_opts,
)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
encoder_input_nodes = encoder.get_inputs()
encoder_out_nodes = encoder.get_outputs()
encoder_out, encoder_out_lens = encoder.run(
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
{
encoder_input_nodes[0].name: features.numpy(),
encoder_input_nodes[1].name: feature_lengths.numpy(),
},
)
ctc_output_input_nodes = ctc_output.get_inputs()
ctc_output_out_nodes = ctc_output.get_outputs()
ctc_out = ctc_output.run(
[ctc_output_out_nodes[0].name],
{
ctc_output_input_nodes[0].name: encoder_out,
},
)[0]
lconv_input_nodes = lconv.get_inputs()
lconv_out_nodes = lconv.get_outputs()
encoder_out = lconv.run(
[lconv_out_nodes[0].name],
{
lconv_input_nodes[0].name: encoder_out,
lconv_input_nodes[1]
.name: make_pad_mask(torch.from_numpy(encoder_out_lens))
.numpy(),
},
)[0]
frame_reducer_input_nodes = frame_reducer.get_inputs()
frame_reducer_out_nodes = frame_reducer.get_outputs()
encoder_out_fr, encoder_out_lens_fr = frame_reducer.run(
[frame_reducer_out_nodes[0].name, frame_reducer_out_nodes[1].name],
{
frame_reducer_input_nodes[0].name: encoder_out,
frame_reducer_input_nodes[1].name: encoder_out_lens,
frame_reducer_input_nodes[2].name: ctc_out,
},
)
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
joiner_encoder_proj=joiner_encoder_proj,
joiner_decoder_proj=joiner_decoder_proj,
encoder_out=encoder_out_fr,
encoder_out_lens=encoder_out_lens_fr,
context_size=args.context_size,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()