mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Add export to ONNX for Zipformer+CTC using blank skip (#861)
* Add export to ONNX for Zipformer+CTC using blank skip --------- Co-authored-by: yifanyang <yifanyeung@yifanyangs-MacBook-Pro.local>
This commit is contained in:
parent
e9019511eb
commit
d8234e199c
@ -72,14 +72,14 @@ Check ./pretrained.py for its usage.
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
|
||||
git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
|
||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
@ -0,0 +1,665 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to ONNX format
|
||||
|
||||
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 13
|
||||
|
||||
It will generate the following files in the given `exp_dir`.
|
||||
Check `onnx_check.py` for how to use them.
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
- lconv.onnx
|
||||
- frame_reducer.onnx
|
||||
|
||||
Please see ./onnx_pretrained.py for usage of the generated files
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa-onnx
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
|
||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_ctc_bs/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""If True, --jit is ignored and it exports the model
|
||||
to onnx format. It will generate the following files:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
- lconv.onnx
|
||||
- frame_reducer.onnx
|
||||
|
||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T, C)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(15, 2000, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([2000] * 15, dtype=torch.int64)
|
||||
|
||||
# encoder_model = torch.jit.script(encoder_model)
|
||||
# It throws the following error for the above statement
|
||||
#
|
||||
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
||||
# 11 is not supported. Please feel free to request support or
|
||||
# submit a pull request on PyTorch GitHub.
|
||||
#
|
||||
# I cannot find which statement causes the above error.
|
||||
# torch.onnx.export() will use torch.jit.trace() internally, which
|
||||
# works well for the current reworked model
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, x_lens),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens"],
|
||||
output_names=["encoder_out", "encoder_out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
y = torch.zeros(15, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = False # Always False, so we can use torch.jit.trace() here
|
||||
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
||||
# in this case
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
(y, need_pad),
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y", "need_pad"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
|
||||
The exported encoder_proj model has one input:
|
||||
|
||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
|
||||
The exported decoder_proj model has one input:
|
||||
|
||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||
|
||||
and produces one output:
|
||||
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
|
||||
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
|
||||
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
||||
|
||||
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
||||
|
||||
project_input = False
|
||||
# Note: It uses torch.jit.trace() internally
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out, project_input),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
"project_input",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.encoder_proj,
|
||||
encoder_out,
|
||||
encoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out"],
|
||||
output_names=["projected_encoder_out"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"projected_encoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_proj_filename}")
|
||||
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.decoder_proj,
|
||||
decoder_out,
|
||||
decoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["decoder_out"],
|
||||
output_names=["projected_decoder_out"],
|
||||
dynamic_axes={
|
||||
"decoder_out": {0: "N"},
|
||||
"projected_decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_proj_filename}")
|
||||
|
||||
|
||||
def export_lconv_onnx(
|
||||
lconv: nn.Module,
|
||||
lconv_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the lconv to ONNX format.
|
||||
|
||||
The exported lconv has two inputs:
|
||||
|
||||
- lconv_input: a tensor of shape (N, T, C)
|
||||
- src_key_padding_mask: a tensor of shape (N, T)
|
||||
|
||||
and has one output:
|
||||
|
||||
- lconv_out: a tensor of shape (N, T, C)
|
||||
|
||||
Args:
|
||||
lconv:
|
||||
The lconv to be exported.
|
||||
lconv_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32)
|
||||
src_key_padding_mask = torch.zeros(15, 498, dtype=torch.bool)
|
||||
|
||||
torch.onnx.export(
|
||||
lconv,
|
||||
(lconv_input, src_key_padding_mask),
|
||||
lconv_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["lconv_input", "src_key_padding_mask"],
|
||||
output_names=["lconv_out"],
|
||||
dynamic_axes={
|
||||
"lconv_input": {0: "N", 1: "T"},
|
||||
"src_key_padding_mask": {0: "N", 1: "T"},
|
||||
"lconv_out": {0: "N", 1: "T"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {lconv_filename}")
|
||||
|
||||
|
||||
def export_frame_reducer_onnx(
|
||||
frame_reducer: nn.Module,
|
||||
frame_reducer_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the frame_reducer to ONNX format.
|
||||
|
||||
The exported frame_reducer has four inputs:
|
||||
|
||||
- x: a tensor of shape (N, T, C)
|
||||
- x_lens: a tensor of shape (N, T)
|
||||
- ctc_output: a tensor of shape (N, T, vocab_size)
|
||||
- blank_id: an int, always 0
|
||||
|
||||
and has two outputs:
|
||||
|
||||
- x_fr: a tensor of shape (N, T, C)
|
||||
- x_lens_fr: a tensor of shape (N, T)
|
||||
|
||||
Args:
|
||||
frame_reducer:
|
||||
The frame_reducer to be exported.
|
||||
frame_reducer_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(15, 498, 384, dtype=torch.float32)
|
||||
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
|
||||
ctc_output = torch.randn(15, 498, 500, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
frame_reducer,
|
||||
(x, x_lens, ctc_output),
|
||||
frame_reducer_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens", "ctc_output"],
|
||||
output_names=["out", "out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"ctc_output": {0: "N", 1: "T"},
|
||||
"out": {0: "N", 1: "T"},
|
||||
"out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {frame_reducer_filename}")
|
||||
|
||||
|
||||
def export_ctc_output_onnx(
|
||||
ctc_output: nn.Module,
|
||||
ctc_output_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the frame_reducer to ONNX format.
|
||||
|
||||
The exported frame_reducer has one inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, T, C)
|
||||
|
||||
and has one output:
|
||||
|
||||
- ctc_output: a tensor of shape (N, T, vocab_size)
|
||||
|
||||
Args:
|
||||
ctc_output:
|
||||
The ctc_output to be exported.
|
||||
ctc_output_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
encoder_out = torch.zeros(15, 498, 384, dtype=torch.float32)
|
||||
|
||||
torch.onnx.export(
|
||||
ctc_output,
|
||||
(encoder_out),
|
||||
ctc_output_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out"],
|
||||
output_names=["ctc_output"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"ctc_output": {0: "N", 1: "T"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {ctc_output_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
opset_version = 13
|
||||
logging.info("Exporting to onnx format")
|
||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
||||
export_encoder_model_onnx(
|
||||
model.encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||
export_decoder_model_onnx(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||
export_joiner_model_onnx(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
lconv_filename = params.exp_dir / "lconv.onnx"
|
||||
export_lconv_onnx(
|
||||
model.lconv,
|
||||
lconv_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
|
||||
export_frame_reducer_onnx(
|
||||
model.frame_reducer,
|
||||
frame_reducer_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
ctc_output_filename = params.exp_dir / "ctc_output.onnx"
|
||||
export_ctc_output_onnx(
|
||||
model.ctc_output,
|
||||
ctc_output_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
76
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py
Executable file → Normal file
76
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py
Executable file → Normal file
@ -22,7 +22,8 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch.nn.functional as F
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
@ -43,7 +44,6 @@ class FrameReducer(nn.Module):
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
ctc_output: torch.Tensor,
|
||||
blank_id: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -54,26 +54,68 @@ class FrameReducer(nn.Module):
|
||||
`x` before padding.
|
||||
ctc_output:
|
||||
The CTC output with shape [N, T, vocab_size].
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
Returns:
|
||||
x_fr:
|
||||
out:
|
||||
The frame reduced encoder output with shape [N, T', C].
|
||||
x_lens_fr:
|
||||
out_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x_fr` before padding.
|
||||
`out` before padding.
|
||||
"""
|
||||
|
||||
N, T, C = x.size()
|
||||
|
||||
padding_mask = make_pad_mask(x_lens)
|
||||
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
||||
non_blank_mask = (ctc_output[:, :, 0] < math.log(0.9)) * (~padding_mask)
|
||||
|
||||
frames_list: List[torch.Tensor] = []
|
||||
lens_list: List[int] = []
|
||||
for i in range(x.shape[0]):
|
||||
frames = x[i][non_blank_mask[i]]
|
||||
frames_list.append(frames)
|
||||
lens_list.append(frames.shape[0])
|
||||
x_fr = pad_sequence(frames_list, batch_first=True)
|
||||
x_lens_fr = torch.tensor(lens_list).to(device=x.device)
|
||||
out_lens = non_blank_mask.sum(dim=1)
|
||||
max_len = out_lens.max()
|
||||
pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
|
||||
max_pad_len = pad_lens_list.max()
|
||||
|
||||
return x_fr, x_lens_fr
|
||||
out = F.pad(x, (0, 0, 0, max_pad_len))
|
||||
|
||||
valid_pad_mask = ~make_pad_mask(pad_lens_list)
|
||||
total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
|
||||
|
||||
out = out[total_valid_mask].reshape(N, -1, C)
|
||||
|
||||
return out.to(device=x.device), out_lens.to(device=x.device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
test_times = 10000
|
||||
frame_reducer = FrameReducer()
|
||||
|
||||
# non zero case
|
||||
x = torch.ones(15, 498, 384, dtype=torch.float32)
|
||||
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
|
||||
ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32))
|
||||
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
|
||||
|
||||
avg_time = 0
|
||||
for i in range(test_times):
|
||||
delta_time = time.time()
|
||||
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
|
||||
delta_time = time.time() - delta_time
|
||||
avg_time += delta_time
|
||||
print(x_fr.shape)
|
||||
print(x_lens_fr)
|
||||
print(avg_time / test_times)
|
||||
|
||||
# all zero case
|
||||
x = torch.zeros(15, 498, 384, dtype=torch.float32)
|
||||
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
|
||||
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
|
||||
|
||||
avg_time = 0
|
||||
for i in range(test_times):
|
||||
delta_time = time.time()
|
||||
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
|
||||
delta_time = time.time() - delta_time
|
||||
avg_time += delta_time
|
||||
print(x_fr.shape)
|
||||
print(x_lens_fr)
|
||||
print(avg_time / test_times)
|
||||
|
@ -0,0 +1,461 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 13
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/encoder.onnx \
|
||||
--decoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/decoder.onnx \
|
||||
--joiner-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner.onnx \
|
||||
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_encoder_proj.onnx \
|
||||
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_decoder_proj.onnx \
|
||||
--lconv-filename ./pruned_transducer_stateless7_ctc_bs/exp/lconv.onnx \
|
||||
--frame-reducer-filename ./pruned_transducer_stateless7_ctc_bs/exp/frame_reducer.onnx \
|
||||
--ctc-output-filename ./pruned_transducer_stateless7_ctc_bs/exp/ctc_output.onnx \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-encoder-proj-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner encoder_proj onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-decoder-proj-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner decoder_proj onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lconv-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the lconv onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-reducer-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the frame reducer onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ctc-output-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the ctc_output onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Context size of the decoder model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: ort.InferenceSession,
|
||||
joiner: ort.InferenceSession,
|
||||
joiner_encoder_proj: ort.InferenceSession,
|
||||
joiner_decoder_proj: ort.InferenceSession,
|
||||
encoder_out: np.ndarray,
|
||||
encoder_out_lens: np.ndarray,
|
||||
context_size: int,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
decoder:
|
||||
The decoder model.
|
||||
joiner:
|
||||
The joiner model.
|
||||
joiner_encoder_proj:
|
||||
The joiner encoder projection model.
|
||||
joiner_decoder_proj:
|
||||
The joiner decoder projection model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
encoder_out = torch.from_numpy(encoder_out)
|
||||
encoder_out_lens = torch.from_numpy(encoder_out_lens)
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
projected_encoder_out = joiner_encoder_proj.run(
|
||||
[joiner_encoder_proj.get_outputs()[0].name],
|
||||
{joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
|
||||
)[0]
|
||||
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input_nodes = decoder.get_inputs()
|
||||
decoder_output_nodes = decoder.get_outputs()
|
||||
|
||||
joiner_input_nodes = joiner.get_inputs()
|
||||
joiner_output_nodes = joiner.get_outputs()
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = decoder.run(
|
||||
[decoder_output_nodes[0].name],
|
||||
{
|
||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||
},
|
||||
)[0].squeeze(1)
|
||||
projected_decoder_out = joiner_decoder_proj.run(
|
||||
[joiner_decoder_proj.get_outputs()[0].name],
|
||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
||||
)[0]
|
||||
|
||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = projected_encoder_out[start:end]
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
projected_decoder_out = projected_decoder_out[:batch_size]
|
||||
|
||||
logits = joiner.run(
|
||||
[joiner_output_nodes[0].name],
|
||||
{
|
||||
joiner_input_nodes[0].name: np.expand_dims(
|
||||
np.expand_dims(current_encoder_out, axis=1), axis=1
|
||||
),
|
||||
joiner_input_nodes[1]
|
||||
.name: projected_decoder_out.unsqueeze(1)
|
||||
.unsqueeze(1)
|
||||
.numpy(),
|
||||
},
|
||||
)[0]
|
||||
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = decoder.run(
|
||||
[decoder_output_nodes[0].name],
|
||||
{
|
||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||
},
|
||||
)[0].squeeze(1)
|
||||
projected_decoder_out = joiner_decoder_proj.run(
|
||||
[joiner_decoder_proj.get_outputs()[0].name],
|
||||
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
|
||||
)[0]
|
||||
projected_decoder_out = torch.from_numpy(projected_decoder_out)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
encoder = ort.InferenceSession(
|
||||
args.encoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
decoder = ort.InferenceSession(
|
||||
args.decoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner = ort.InferenceSession(
|
||||
args.joiner_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner_encoder_proj = ort.InferenceSession(
|
||||
args.joiner_encoder_proj_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner_decoder_proj = ort.InferenceSession(
|
||||
args.joiner_decoder_proj_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
lconv = ort.InferenceSession(
|
||||
args.lconv_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
frame_reducer = ort.InferenceSession(
|
||||
args.frame_reducer_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
ctc_output = ort.InferenceSession(
|
||||
args.ctc_output_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
|
||||
encoder_input_nodes = encoder.get_inputs()
|
||||
encoder_out_nodes = encoder.get_outputs()
|
||||
encoder_out, encoder_out_lens = encoder.run(
|
||||
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
|
||||
{
|
||||
encoder_input_nodes[0].name: features.numpy(),
|
||||
encoder_input_nodes[1].name: feature_lengths.numpy(),
|
||||
},
|
||||
)
|
||||
|
||||
ctc_output_input_nodes = ctc_output.get_inputs()
|
||||
ctc_output_out_nodes = ctc_output.get_outputs()
|
||||
ctc_out = ctc_output.run(
|
||||
[ctc_output_out_nodes[0].name],
|
||||
{
|
||||
ctc_output_input_nodes[0].name: encoder_out,
|
||||
},
|
||||
)[0]
|
||||
|
||||
lconv_input_nodes = lconv.get_inputs()
|
||||
lconv_out_nodes = lconv.get_outputs()
|
||||
encoder_out = lconv.run(
|
||||
[lconv_out_nodes[0].name],
|
||||
{
|
||||
lconv_input_nodes[0].name: encoder_out,
|
||||
lconv_input_nodes[1]
|
||||
.name: make_pad_mask(torch.from_numpy(encoder_out_lens))
|
||||
.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
frame_reducer_input_nodes = frame_reducer.get_inputs()
|
||||
frame_reducer_out_nodes = frame_reducer.get_outputs()
|
||||
encoder_out_fr, encoder_out_lens_fr = frame_reducer.run(
|
||||
[frame_reducer_out_nodes[0].name, frame_reducer_out_nodes[1].name],
|
||||
{
|
||||
frame_reducer_input_nodes[0].name: encoder_out,
|
||||
frame_reducer_input_nodes[1].name: encoder_out_lens,
|
||||
frame_reducer_input_nodes[2].name: ctc_out,
|
||||
},
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
joiner_encoder_proj=joiner_encoder_proj,
|
||||
joiner_decoder_proj=joiner_decoder_proj,
|
||||
encoder_out=encoder_out_fr,
|
||||
encoder_out_lens=encoder_out_lens_fr,
|
||||
context_size=args.context_size,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = sp.decode(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user