mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
add onnx export for stateless2 (#1086)
This commit is contained in:
parent
ea8b15309f
commit
1df71a6b38
517
egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py
Executable file
517
egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py
Executable file
@ -0,0 +1,517 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a transducer model from PyTorch to ONNX.
|
||||||
|
|
||||||
|
We use the pre-trained model from
|
||||||
|
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
|
||||||
|
as an example to show how to use this file.
|
||||||
|
|
||||||
|
1. Download the pre-trained model
|
||||||
|
|
||||||
|
cd egs/wenetspeech/ASR
|
||||||
|
|
||||||
|
repo_url=icefall_asr_wenetspeech_pruned_transducer_stateless2
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_char/Linv.pt"
|
||||||
|
git lfs pull --include "exp/pretrained_epoch_10_avg_2.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained_epoch_10_avg_2.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
2. Export the model to ONNX
|
||||||
|
|
||||||
|
./pruned_transducer_stateless2/export-onnx.py \
|
||||||
|
--lang-dir $repo/data/lang_char \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp
|
||||||
|
|
||||||
|
It will generate the following 3 files inside $repo/exp:
|
||||||
|
|
||||||
|
- encoder-epoch-99-avg-1.onnx
|
||||||
|
- decoder-epoch-99-avg-1.onnx
|
||||||
|
- joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
See ./onnx_pretrained.py for how to
|
||||||
|
use the exported ONNX models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from conformer import Conformer
|
||||||
|
from decoder import Decoder
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless5/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = value
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxEncoder(nn.Module):
|
||||||
|
"""A wrapper for Conformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, encoder: Conformer, encoder_proj: nn.Linear):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
A Conformer encoder.
|
||||||
|
encoder_proj:
|
||||||
|
The projection layer for encoder from the joiner.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Please see the help information of Conformer.forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||||
|
"""
|
||||||
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
|
||||||
|
|
||||||
|
encoder_out = self.encoder_proj(encoder_out)
|
||||||
|
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||||
|
|
||||||
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxDecoder(nn.Module):
|
||||||
|
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder
|
||||||
|
self.decoder_proj = decoder_proj
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, context_size).
|
||||||
|
Returns
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||||
|
decoder_output = decoder_output.squeeze(1)
|
||||||
|
output = self.decoder_proj(decoder_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxJoiner(nn.Module):
|
||||||
|
"""A wrapper for the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, output_linear: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.output_linear = output_linear
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
logit = encoder_out + decoder_out
|
||||||
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
return logit
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: OnnxEncoder,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model to ONNX format.
|
||||||
|
The exported model has two inputs:
|
||||||
|
|
||||||
|
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||||
|
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||||
|
|
||||||
|
and it has two outputs:
|
||||||
|
|
||||||
|
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, a tensor of shape (N,)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, x_lens),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["x", "x_lens"],
|
||||||
|
output_names=["encoder_out", "encoder_out_lens"],
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N", 1: "T"},
|
||||||
|
"x_lens": {0: "N"},
|
||||||
|
"encoder_out": {0: "N", 1: "T"},
|
||||||
|
"encoder_out_lens": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "conformer",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"comment": "stateless5",
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx(
|
||||||
|
decoder_model: OnnxDecoder,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
context_size = decoder_model.decoder.context_size
|
||||||
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
|
|
||||||
|
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
y,
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"context_size": str(context_size),
|
||||||
|
"vocab_size": str(vocab_size),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"joiner_dim": str(joiner_dim),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
params.blank_id = 0
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if start >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
|
encoder = OnnxEncoder(
|
||||||
|
encoder=model.encoder,
|
||||||
|
encoder_proj=model.joiner.encoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = OnnxDecoder(
|
||||||
|
decoder=model.decoder,
|
||||||
|
decoder_proj=model.joiner.decoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
suffix = f"iter-{params.iter}"
|
||||||
|
else:
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
suffix += f"-avg-{params.avg}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
main()
|
@ -59,23 +59,7 @@ It will generate the following files:
|
|||||||
|
|
||||||
Check ./jit_pretrained.py for usage.
|
Check ./jit_pretrained.py for usage.
|
||||||
|
|
||||||
(3) Export to ONNX format
|
(3) Export `model.state_dict()`
|
||||||
|
|
||||||
./pruned_transducer_stateless2/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
|
||||||
--lang-dir data/lang_char \
|
|
||||||
--epoch 10 \
|
|
||||||
--avg 2 \
|
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
Refer to ./onnx_check.py and ./onnx_pretrained.py
|
|
||||||
for usage.
|
|
||||||
|
|
||||||
Check
|
|
||||||
https://github.com/k2-fsa/sherpa-onnx
|
|
||||||
for how to use the exported models outside of icefall.
|
|
||||||
|
|
||||||
(4) Export `model.state_dict()`
|
|
||||||
|
|
||||||
./pruned_transducer_stateless2/export.py \
|
./pruned_transducer_stateless2/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
@ -184,23 +168,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--onnx",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""If True, --jit is ignored and it exports the model
|
|
||||||
to onnx format. It will generate the following files:
|
|
||||||
|
|
||||||
- encoder.onnx
|
|
||||||
- decoder.onnx
|
|
||||||
- joiner.onnx
|
|
||||||
- joiner_encoder_proj.onnx
|
|
||||||
- joiner_decoder_proj.onnx
|
|
||||||
|
|
||||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -333,206 +300,6 @@ def export_joiner_model_jit_trace(
|
|||||||
logging.info(f"Saved to {joiner_filename}")
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
def export_encoder_model_onnx(
|
|
||||||
encoder_model: nn.Module,
|
|
||||||
encoder_filename: str,
|
|
||||||
opset_version: int = 11,
|
|
||||||
) -> None:
|
|
||||||
"""Export the given encoder model to ONNX format.
|
|
||||||
The exported model has two inputs:
|
|
||||||
|
|
||||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
|
||||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
|
||||||
|
|
||||||
and it has two outputs:
|
|
||||||
|
|
||||||
- encoder_out, a tensor of shape (N, T, C)
|
|
||||||
- encoder_out_lens, a tensor of shape (N,)
|
|
||||||
|
|
||||||
Note: The warmup argument is fixed to 1.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
encoder_model:
|
|
||||||
The input encoder model
|
|
||||||
encoder_filename:
|
|
||||||
The filename to save the exported ONNX model.
|
|
||||||
opset_version:
|
|
||||||
The opset version to use.
|
|
||||||
"""
|
|
||||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
|
||||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
|
||||||
|
|
||||||
# encoder_model = torch.jit.script(encoder_model)
|
|
||||||
# It throws the following error for the above statement
|
|
||||||
#
|
|
||||||
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
|
||||||
# 11 is not supported. Please feel free to request support or
|
|
||||||
# submit a pull request on PyTorch GitHub.
|
|
||||||
#
|
|
||||||
# I cannot find which statement causes the above error.
|
|
||||||
# torch.onnx.export() will use torch.jit.trace() internally, which
|
|
||||||
# works well for the current reworked model
|
|
||||||
warmup = 1.0
|
|
||||||
torch.onnx.export(
|
|
||||||
encoder_model,
|
|
||||||
(x, x_lens, warmup),
|
|
||||||
encoder_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["x", "x_lens", "warmup"],
|
|
||||||
output_names=["encoder_out", "encoder_out_lens"],
|
|
||||||
dynamic_axes={
|
|
||||||
"x": {0: "N", 1: "T"},
|
|
||||||
"x_lens": {0: "N"},
|
|
||||||
"encoder_out": {0: "N", 1: "T"},
|
|
||||||
"encoder_out_lens": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {encoder_filename}")
|
|
||||||
|
|
||||||
|
|
||||||
def export_decoder_model_onnx(
|
|
||||||
decoder_model: nn.Module,
|
|
||||||
decoder_filename: str,
|
|
||||||
opset_version: int = 11,
|
|
||||||
) -> None:
|
|
||||||
"""Export the decoder model to ONNX format.
|
|
||||||
|
|
||||||
The exported model has one input:
|
|
||||||
|
|
||||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
|
||||||
|
|
||||||
and has one output:
|
|
||||||
|
|
||||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
|
||||||
|
|
||||||
Note: The argument need_pad is fixed to False.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
decoder_model:
|
|
||||||
The decoder model to be exported.
|
|
||||||
decoder_filename:
|
|
||||||
Filename to save the exported ONNX model.
|
|
||||||
opset_version:
|
|
||||||
The opset version to use.
|
|
||||||
"""
|
|
||||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
|
||||||
need_pad = False # Always False, so we can use torch.jit.trace() here
|
|
||||||
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
|
||||||
# in this case
|
|
||||||
torch.onnx.export(
|
|
||||||
decoder_model,
|
|
||||||
(y, need_pad),
|
|
||||||
decoder_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["y", "need_pad"],
|
|
||||||
output_names=["decoder_out"],
|
|
||||||
dynamic_axes={
|
|
||||||
"y": {0: "N"},
|
|
||||||
"decoder_out": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {decoder_filename}")
|
|
||||||
|
|
||||||
|
|
||||||
def export_joiner_model_onnx(
|
|
||||||
joiner_model: nn.Module,
|
|
||||||
joiner_filename: str,
|
|
||||||
opset_version: int = 11,
|
|
||||||
) -> None:
|
|
||||||
"""Export the joiner model to ONNX format.
|
|
||||||
The exported joiner model has two inputs:
|
|
||||||
|
|
||||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
|
|
||||||
and produces one output:
|
|
||||||
|
|
||||||
- logit: a tensor of shape (N, vocab_size)
|
|
||||||
|
|
||||||
The exported encoder_proj model has one input:
|
|
||||||
|
|
||||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
|
||||||
|
|
||||||
and produces one output:
|
|
||||||
|
|
||||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
|
|
||||||
The exported decoder_proj model has one input:
|
|
||||||
|
|
||||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
|
||||||
|
|
||||||
and produces one output:
|
|
||||||
|
|
||||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
|
||||||
"""
|
|
||||||
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
|
|
||||||
|
|
||||||
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
|
|
||||||
|
|
||||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
|
||||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
|
||||||
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
|
||||||
|
|
||||||
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
|
||||||
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
|
||||||
|
|
||||||
project_input = False
|
|
||||||
# Note: It uses torch.jit.trace() internally
|
|
||||||
torch.onnx.export(
|
|
||||||
joiner_model,
|
|
||||||
(projected_encoder_out, projected_decoder_out, project_input),
|
|
||||||
joiner_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=[
|
|
||||||
"projected_encoder_out",
|
|
||||||
"projected_decoder_out",
|
|
||||||
"project_input",
|
|
||||||
],
|
|
||||||
output_names=["logit"],
|
|
||||||
dynamic_axes={
|
|
||||||
"projected_encoder_out": {0: "N"},
|
|
||||||
"projected_decoder_out": {0: "N"},
|
|
||||||
"logit": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {joiner_filename}")
|
|
||||||
|
|
||||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
|
||||||
torch.onnx.export(
|
|
||||||
joiner_model.encoder_proj,
|
|
||||||
encoder_out,
|
|
||||||
encoder_proj_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["encoder_out"],
|
|
||||||
output_names=["projected_encoder_out"],
|
|
||||||
dynamic_axes={
|
|
||||||
"encoder_out": {0: "N"},
|
|
||||||
"projected_encoder_out": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {encoder_proj_filename}")
|
|
||||||
|
|
||||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
|
||||||
torch.onnx.export(
|
|
||||||
joiner_model.decoder_proj,
|
|
||||||
decoder_out,
|
|
||||||
decoder_proj_filename,
|
|
||||||
verbose=False,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=["decoder_out"],
|
|
||||||
output_names=["projected_decoder_out"],
|
|
||||||
dynamic_axes={
|
|
||||||
"decoder_out": {0: "N"},
|
|
||||||
"projected_decoder_out": {0: "N"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logging.info(f"Saved to {decoder_proj_filename}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
@ -573,31 +340,7 @@ def main():
|
|||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.onnx is True:
|
if params.jit:
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
|
||||||
opset_version = 11
|
|
||||||
logging.info("Exporting to onnx format")
|
|
||||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
|
||||||
export_encoder_model_onnx(
|
|
||||||
model.encoder,
|
|
||||||
encoder_filename,
|
|
||||||
opset_version=opset_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
|
||||||
export_decoder_model_onnx(
|
|
||||||
model.decoder,
|
|
||||||
decoder_filename,
|
|
||||||
opset_version=opset_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
|
||||||
export_joiner_model_onnx(
|
|
||||||
model.joiner,
|
|
||||||
joiner_filename,
|
|
||||||
opset_version=opset_version,
|
|
||||||
)
|
|
||||||
elif params.jit:
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
@ -1,390 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
This script loads ONNX models and uses them to decode waves.
|
|
||||||
You can use the following command to get the exported models:
|
|
||||||
|
|
||||||
./pruned_transducer_stateless2/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
|
||||||
--lang-dir data/lang_char \
|
|
||||||
--epoch 20 \
|
|
||||||
--avg 10 \
|
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
Usage of this script:
|
|
||||||
|
|
||||||
./pruned_transducer_stateless3/onnx_pretrained.py \
|
|
||||||
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
|
|
||||||
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
|
|
||||||
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
|
|
||||||
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \
|
|
||||||
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \
|
|
||||||
--tokens data/lang_char/tokens.txt \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
|
|
||||||
We provide pretrained models at:
|
|
||||||
https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import k2
|
|
||||||
import kaldifeat
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from icefall import is_module_available
|
|
||||||
|
|
||||||
if not is_module_available("onnxruntime"):
|
|
||||||
raise ValueError("Please 'pip install onnxruntime' first.")
|
|
||||||
|
|
||||||
import onnxruntime as ort
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--encoder-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the encoder onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decoder-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the decoder onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--joiner-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the joiner onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--joiner-encoder-proj-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the joiner encoder_proj onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--joiner-decoder-proj-model-filename",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the joiner decoder_proj onnx model. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--tokens",
|
|
||||||
type=str,
|
|
||||||
help="""Path to tokens.txt""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"sound_files",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
help="The input sound file(s) to transcribe. "
|
|
||||||
"Supported formats are those supported by torchaudio.load(). "
|
|
||||||
"For example, wav and flac are supported. "
|
|
||||||
"The sample rate has to be 16kHz.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--sample-rate",
|
|
||||||
type=int,
|
|
||||||
default=16000,
|
|
||||||
help="The sample rate of the input sound file",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--context-size",
|
|
||||||
type=int,
|
|
||||||
default=2,
|
|
||||||
help="Context size of the decoder model",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def read_sound_files(
|
|
||||||
filenames: List[str], expected_sample_rate: float
|
|
||||||
) -> List[torch.Tensor]:
|
|
||||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
|
||||||
Args:
|
|
||||||
filenames:
|
|
||||||
A list of sound filenames.
|
|
||||||
expected_sample_rate:
|
|
||||||
The expected sample rate of the sound files.
|
|
||||||
Returns:
|
|
||||||
Return a list of 1-D float32 torch tensors.
|
|
||||||
"""
|
|
||||||
ans = []
|
|
||||||
for f in filenames:
|
|
||||||
wave, sample_rate = torchaudio.load(f)
|
|
||||||
assert (
|
|
||||||
sample_rate == expected_sample_rate
|
|
||||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
|
||||||
# We use only the first channel
|
|
||||||
ans.append(wave[0])
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
|
||||||
decoder: ort.InferenceSession,
|
|
||||||
joiner: ort.InferenceSession,
|
|
||||||
joiner_encoder_proj: ort.InferenceSession,
|
|
||||||
joiner_decoder_proj: ort.InferenceSession,
|
|
||||||
encoder_out: np.ndarray,
|
|
||||||
encoder_out_lens: np.ndarray,
|
|
||||||
context_size: int,
|
|
||||||
) -> List[List[int]]:
|
|
||||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
|
||||||
Args:
|
|
||||||
decoder:
|
|
||||||
The decoder model.
|
|
||||||
joiner:
|
|
||||||
The joiner model.
|
|
||||||
joiner_encoder_proj:
|
|
||||||
The joiner encoder projection model.
|
|
||||||
joiner_decoder_proj:
|
|
||||||
The joiner decoder projection model.
|
|
||||||
encoder_out:
|
|
||||||
A 3-D tensor of shape (N, T, C)
|
|
||||||
encoder_out_lens:
|
|
||||||
A 1-D tensor of shape (N,).
|
|
||||||
context_size:
|
|
||||||
The context size of the decoder model.
|
|
||||||
Returns:
|
|
||||||
Return the decoded results for each utterance.
|
|
||||||
"""
|
|
||||||
encoder_out = torch.from_numpy(encoder_out)
|
|
||||||
encoder_out_lens = torch.from_numpy(encoder_out_lens)
|
|
||||||
assert encoder_out.ndim == 3
|
|
||||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
|
||||||
|
|
||||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
|
||||||
input=encoder_out,
|
|
||||||
lengths=encoder_out_lens.cpu(),
|
|
||||||
batch_first=True,
|
|
||||||
enforce_sorted=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
projected_encoder_out = joiner_encoder_proj.run(
|
|
||||||
[joiner_encoder_proj.get_outputs()[0].name],
|
|
||||||
{joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
blank_id = 0 # hard-code to 0
|
|
||||||
|
|
||||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
|
||||||
N = encoder_out.size(0)
|
|
||||||
|
|
||||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
|
||||||
assert N == batch_size_list[0], (N, batch_size_list)
|
|
||||||
|
|
||||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
|
||||||
|
|
||||||
decoder_input_nodes = decoder.get_inputs()
|
|
||||||
decoder_output_nodes = decoder.get_outputs()
|
|
||||||
|
|
||||||
joiner_input_nodes = joiner.get_inputs()
|
|
||||||
joiner_output_nodes = joiner.get_outputs()
|
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
|
||||||
hyps,
|
|
||||||
dtype=torch.int64,
|
|
||||||
) # (N, context_size)
|
|
||||||
|
|
||||||
decoder_out = decoder.run(
|
|
||||||
[decoder_output_nodes[0].name],
|
|
||||||
{
|
|
||||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
|
||||||
},
|
|
||||||
)[0].squeeze(1)
|
|
||||||
projected_decoder_out = joiner_decoder_proj.run(
|
|
||||||
[joiner_decoder_proj.get_outputs()[0].name],
|
|
||||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
|
||||||
|
|
||||||
offset = 0
|
|
||||||
for batch_size in batch_size_list:
|
|
||||||
start = offset
|
|
||||||
end = offset + batch_size
|
|
||||||
current_encoder_out = projected_encoder_out[start:end]
|
|
||||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
|
||||||
offset = end
|
|
||||||
|
|
||||||
projected_decoder_out = projected_decoder_out[:batch_size]
|
|
||||||
|
|
||||||
logits = joiner.run(
|
|
||||||
[joiner_output_nodes[0].name],
|
|
||||||
{
|
|
||||||
joiner_input_nodes[0].name: current_encoder_out,
|
|
||||||
joiner_input_nodes[1].name: projected_decoder_out.numpy(),
|
|
||||||
},
|
|
||||||
)[0]
|
|
||||||
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
|
|
||||||
# logits'shape (batch_size, vocab_size)
|
|
||||||
|
|
||||||
assert logits.ndim == 2, logits.shape
|
|
||||||
y = logits.argmax(dim=1).tolist()
|
|
||||||
emitted = False
|
|
||||||
for i, v in enumerate(y):
|
|
||||||
if v != blank_id:
|
|
||||||
hyps[i].append(v)
|
|
||||||
emitted = True
|
|
||||||
if emitted:
|
|
||||||
# update decoder output
|
|
||||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
|
||||||
decoder_input = torch.tensor(
|
|
||||||
decoder_input,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
decoder_out = decoder.run(
|
|
||||||
[decoder_output_nodes[0].name],
|
|
||||||
{
|
|
||||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
|
||||||
},
|
|
||||||
)[0].squeeze(1)
|
|
||||||
projected_decoder_out = joiner_decoder_proj.run(
|
|
||||||
[joiner_decoder_proj.get_outputs()[0].name],
|
|
||||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
|
||||||
)[0]
|
|
||||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
|
||||||
|
|
||||||
sorted_ans = [h[context_size:] for h in hyps]
|
|
||||||
ans = []
|
|
||||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
|
||||||
for i in range(N):
|
|
||||||
ans.append(sorted_ans[unsorted_indices[i]])
|
|
||||||
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def main():
|
|
||||||
parser = get_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
logging.info(vars(args))
|
|
||||||
|
|
||||||
session_opts = ort.SessionOptions()
|
|
||||||
session_opts.inter_op_num_threads = 1
|
|
||||||
session_opts.intra_op_num_threads = 1
|
|
||||||
|
|
||||||
encoder = ort.InferenceSession(
|
|
||||||
args.encoder_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder = ort.InferenceSession(
|
|
||||||
args.decoder_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
joiner = ort.InferenceSession(
|
|
||||||
args.joiner_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
joiner_encoder_proj = ort.InferenceSession(
|
|
||||||
args.joiner_encoder_proj_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
joiner_decoder_proj = ort.InferenceSession(
|
|
||||||
args.joiner_decoder_proj_model_filename,
|
|
||||||
sess_options=session_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Constructing Fbank computer")
|
|
||||||
opts = kaldifeat.FbankOptions()
|
|
||||||
opts.device = "cpu"
|
|
||||||
opts.frame_opts.dither = 0
|
|
||||||
opts.frame_opts.snip_edges = False
|
|
||||||
opts.frame_opts.samp_freq = args.sample_rate
|
|
||||||
opts.mel_opts.num_bins = 80
|
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
|
||||||
|
|
||||||
logging.info(f"Reading sound files: {args.sound_files}")
|
|
||||||
waves = read_sound_files(
|
|
||||||
filenames=args.sound_files,
|
|
||||||
expected_sample_rate=args.sample_rate,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Decoding started")
|
|
||||||
features = fbank(waves)
|
|
||||||
feature_lengths = [f.size(0) for f in features]
|
|
||||||
|
|
||||||
features = pad_sequence(
|
|
||||||
features,
|
|
||||||
batch_first=True,
|
|
||||||
padding_value=math.log(1e-10),
|
|
||||||
)
|
|
||||||
|
|
||||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
|
||||||
|
|
||||||
encoder_input_nodes = encoder.get_inputs()
|
|
||||||
encoder_out_nodes = encoder.get_outputs()
|
|
||||||
encoder_out, encoder_out_lens = encoder.run(
|
|
||||||
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
|
|
||||||
{
|
|
||||||
encoder_input_nodes[0].name: features.numpy(),
|
|
||||||
encoder_input_nodes[1].name: feature_lengths.numpy(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
hyps = greedy_search(
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
joiner_encoder_proj=joiner_encoder_proj,
|
|
||||||
joiner_decoder_proj=joiner_decoder_proj,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
context_size=args.context_size,
|
|
||||||
)
|
|
||||||
symbol_table = k2.SymbolTable.from_file(args.tokens)
|
|
||||||
s = "\n"
|
|
||||||
for filename, hyp in zip(args.sound_files, hyps):
|
|
||||||
words = "".join([symbol_table[i] for i in hyp])
|
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
main()
|
|
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless5/onnx_pretrained.py
|
Loading…
x
Reference in New Issue
Block a user