support decoding and computing RTF on test sets with onnx models (#995)

* support decode and compute RTF on test sets with onnx models

* support onnx export and decode in pruned_transducer_stateless
This commit is contained in:
Zengwei Yao 2023-04-12 19:04:50 +08:00 committed by GitHub
parent dbf2aa3212
commit 5f066d3d53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1901 additions and 29 deletions

View File

@ -0,0 +1,527 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-9999.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-9999-avg-1.onnx
- decoder-epoch-9999-avg-1.onnx
- joiner-epoch-9999-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Conformer"""
def __init__(self, encoder: Conformer):
"""
Args:
encoder:
A Conformer encoder.
"""
super().__init__()
self.encoder = encoder
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Conformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, encoder_out_lens
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder"""
def __init__(self, decoder: Decoder):
super().__init__()
self.decoder = decoder
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
output = decoder_output.squeeze(1)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, inner_linear: nn.Linear, output_linear: nn.Linear):
super().__init__()
self.inner_linear = inner_linear
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.inner_linear(torch.tanh(logit))
output = self.output_linear(nn.functional.relu(logit))
return output
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
meta_data = {
"model_type": "conformer",
"version": "1",
"model_author": "k2-fsa",
"comment": "stateless3",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.inner_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to("cpu")
model.eval()
encoder = OnnxEncoder(encoder=model.encoder)
decoder = OnnxDecoder(decoder=model.decoder)
joiner = OnnxJoiner(
inner_linear=model.joiner.inner_linear, output_linear=model.joiner.output_linear
)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/onnx_check.py

View File

@ -0,0 +1,319 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-9999.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-9999-avg-1.onnx
- decoder-epoch-9999-avg-1.onnx
- joiner-epoch-9999-avg-1.onnx
3. Run this file
./pruned_transducer_stateless/onnx_decode.py \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser
def decode_one_batch(
model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
) -> List[List[str]]:
"""Decode one batch and return the result.
Currently it only greedy_search is supported.
Args:
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoded results for each utterance.
"""
feature = batch["inputs"]
assert feature.ndim == 3
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
hyps = greedy_search(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
)
hyps = [sp.decode(h).split() for h in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
model:
The neural model.
sp:
The BPE model.
Returns:
- A list of tuples. Each tuple contains three elements:
- cut_id,
- reference transcript,
- predicted result.
- The total duration (in seconds) of the dataset.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 10
total_duration = 0
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
hyps = decode_one_batch(model=model, sp=sp, batch=batch)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, total_duration
def save_results(
res_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
):
recog_path = res_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = res_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("WER", file=f)
print(wer, file=f)
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
args.decoding_method == "greedy_search"
), "Only supports greedy_search currently."
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
setup_logger(f"{res_dir}/log-decode")
logging.info("Decoding started")
device = torch.device("cpu")
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
blank_id = sp.piece_to_id("<blk>")
assert blank_id == 0, blank_id
logging.info(vars(args))
logging.info("About to create model")
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time()
results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
logging.info(f"Wave duration: {total_duration:.3f} s")
logging.info(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/onnx_pretrained.py

View File

@ -0,0 +1,321 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt"
cd exp
ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless3/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-9999-avg-1.onnx
- decoder-epoch-9999-avg-1.onnx
- joiner-epoch-9999-avg-1.onnx
2. Run this file
./pruned_transducer_stateless3/onnx_decode.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from librispeech import LibriSpeech
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser
def decode_one_batch(
model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
) -> List[List[str]]:
"""Decode one batch and return the result.
Currently it only greedy_search is supported.
Args:
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoded results for each utterance.
"""
feature = batch["inputs"]
assert feature.ndim == 3
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
hyps = greedy_search(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
)
hyps = [sp.decode(h).split() for h in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
model:
The neural model.
sp:
The BPE model.
Returns:
- A list of tuples. Each tuple contains three elements:
- cut_id,
- reference transcript,
- predicted result.
- The total duration (in seconds) of the dataset.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 10
total_duration = 0
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
hyps = decode_one_batch(model=model, sp=sp, batch=batch)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, total_duration
def save_results(
res_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
):
recog_path = res_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = res_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("WER", file=f)
print(wer, file=f)
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
args.decoding_method == "greedy_search"
), "Only supports greedy_search currently."
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
setup_logger(f"{res_dir}/log-decode")
logging.info("Decoding started")
device = torch.device("cpu")
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
blank_id = sp.piece_to_id("<blk>")
assert blank_id == 0, blank_id
logging.info(vars(args))
logging.info("About to create model")
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
# we need cut ids to display recognition results.
args.return_cuts = True
asr_datamodule = AsrDataModule(args)
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts)
test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time()
results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
logging.info(f"Wave duration: {total_duration:.3f} s")
logging.info(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -71,7 +71,6 @@ from typing import List, Tuple
import k2
import kaldifeat
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
@ -139,7 +138,7 @@ class OnnxModel:
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts

View File

@ -60,6 +60,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from onnxruntime.quantization import QuantType, quantize_dynamic
from decoder import Decoder
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -568,6 +569,35 @@ def main():
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

View File

@ -0,0 +1,326 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-10.pt"
cd exp
ln -s pretrained-epoch-30-avg-10.pt epoch-9999.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 9999 \
--avg 1 \
--use-averaged-model 0 \
--exp-dir $repo/exp/ \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-9999-avg-1.onnx
- decoder-epoch-9999-avg-1.onnx
- joiner-epoch-9999-avg-1.onnx
2. Run this file
./pruned_transducer_stateless5/onnx_decode.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser
def decode_one_batch(
model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
) -> List[List[str]]:
"""Decode one batch and return the result.
Currently it only greedy_search is supported.
Args:
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoded results for each utterance.
"""
feature = batch["inputs"]
assert feature.ndim == 3
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
hyps = greedy_search(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
)
hyps = [sp.decode(h).split() for h in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
model:
The neural model.
sp:
The BPE model.
Returns:
- A list of tuples. Each tuple contains three elements:
- cut_id,
- reference transcript,
- predicted result.
- The total duration (in seconds) of the dataset.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 10
total_duration = 0
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
hyps = decode_one_batch(model=model, sp=sp, batch=batch)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, total_duration
def save_results(
res_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
):
recog_path = res_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = res_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("WER", file=f)
print(wer, file=f)
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
args.decoding_method == "greedy_search"
), "Only supports greedy_search currently."
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
setup_logger(f"{res_dir}/log-decode")
logging.info("Decoding started")
device = torch.device("cpu")
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
blank_id = sp.piece_to_id("<blk>")
assert blank_id == 0, blank_id
logging.info(vars(args))
logging.info("About to create model")
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time()
results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
logging.info(f"Wave duration: {total_duration:.3f} s")
logging.info(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,319 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.
We use the pre-trained model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
cd exp
ln -s pretrained-epoch-30-avg-9.pt epoch-9999.pt
popd
2. Export the model to ONNX
./pruned_transducer_stateless7/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-9999-avg-1.onnx
- decoder-epoch-9999-avg-1.onnx
- joiner-epoch-9999-avg-1.onnx
2. Run this file
./pruned_transducer_stateless7/onnx_decode.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import greedy_search, OnnxModel
from icefall.utils import setup_logger, store_transcripts, write_error_stats
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser
def decode_one_batch(
model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict
) -> List[List[str]]:
"""Decode one batch and return the result.
Currently it only greedy_search is supported.
Args:
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoded results for each utterance.
"""
feature = batch["inputs"]
assert feature.ndim == 3
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
hyps = greedy_search(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
)
hyps = [sp.decode(h).split() for h in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
model:
The neural model.
sp:
The BPE model.
Returns:
- A list of tuples. Each tuple contains three elements:
- cut_id,
- reference transcript,
- predicted result.
- The total duration (in seconds) of the dataset.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 10
total_duration = 0
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
hyps = decode_one_batch(model=model, sp=sp, batch=batch)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, total_duration
def save_results(
res_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
):
recog_path = res_dir / f"recogs-{test_set_name}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = res_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("WER", file=f)
print(wer, file=f)
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
args.decoding_method == "greedy_search"
), "Only supports greedy_search currently."
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
setup_logger(f"{res_dir}/log-decode")
logging.info("Decoding started")
device = torch.device("cpu")
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
blank_id = sp.piece_to_id("<blk>")
assert blank_id == 0, blank_id
logging.info(vars(args))
logging.info("About to create model")
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time()
results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
logging.info(f"Wave duration: {total_duration:.3f} s")
logging.info(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -24,7 +24,7 @@ import torch
from torch import Tensor, nn
from transformer import Transformer
from icefall.utils import make_pad_mask, subsequent_chunk_mask
from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask
class Conformer(Transformer):
@ -154,7 +154,8 @@ class Conformer(Transformer):
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
if not is_jit_tracing():
assert x.size(0) == lengths.max().item()
src_key_padding_mask = make_pad_mask(lengths)
@ -768,6 +769,14 @@ class RelPositionalEncoding(torch.nn.Module):
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
if is_jit_tracing():
# 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e.,
# It assumes that the maximum input won't have more than
# 10k frames.
#
# TODO(fangjun): Use torch.jit.script() for this module
max_len = 10000
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
@ -975,22 +984,34 @@ class RelPositionMultiheadAttention(nn.Module):
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
time2 = time1 + left_context
if not is_jit_tracing():
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
if is_jit_tracing():
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time2)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, time1, time2)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward(
self,
@ -1061,13 +1082,16 @@ class RelPositionMultiheadAttention(nn.Module):
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
if not is_jit_tracing():
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
if not is_jit_tracing():
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if torch.equal(query, key) and torch.equal(key, value):
@ -1181,7 +1205,8 @@ class RelPositionMultiheadAttention(nn.Module):
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
if not is_jit_tracing():
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
@ -1212,11 +1237,12 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if not is_jit_tracing():
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
@ -1265,7 +1291,10 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
if not is_jit_tracing():
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)