Add CTC HLG decoding for zipformer (#1287)

Fangjun Kuang 2023-10-02 14:00:06 +08:00 committed by GitHub
parent f14b673408
commit 109354b6b8
10 changed files with 1545 additions and 21 deletions

View File

@@ -10,7 +10,57 @@ log() {
pushd egs/librispeech/ASR
# repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav
log "CTC greedy search"
./zipformer/onnx_pretrained_ctc.py \
--nn-model $repo/model.onnx \
--tokens $repo/tokens.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav
log "CTC H decoding"
./zipformer/onnx_pretrained_ctc_H.py \
--nn-model $repo/model.onnx \
--tokens $repo/tokens.txt \
--H $repo/H.fst \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav
log "CTC HL decoding"
./zipformer/onnx_pretrained_ctc_HL.py \
--nn-model $repo/model.onnx \
--words $repo/words.txt \
--HL $repo/HL.fst \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav
log "CTC HLG decoding"
./zipformer/onnx_pretrained_ctc_HLG.py \
--nn-model $repo/model.onnx \
--words $repo/words.txt \
--HLG $repo/HLG.fst \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav
rm -rf $repo
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09
log "Downloading pre-trained model from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
@@ -128,7 +178,9 @@ repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "exp/pretrained.pt"
git lfs pull --include "data/lm/G_3_gram_char.fst.txt"
git lfs pull --include "data/lang_char/H.fst"
git lfs pull --include "data/lang_char/HL.fst"
git lfs pull --include "data/lang_char/HLG.fst"
popd
@@ -153,10 +205,6 @@ popd
ls -lh $repo/exp
log "Generating H.fst, HL.fst"
./local/prepare_lang_fst.py --lang-dir $repo/data/lang_char --ngram-G $repo/data/lm/G_3_gram_char.fst.txt
ls -lh $repo/data/lang_char
log "Decoding with H on CPU with OpenFst"

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-conformer-ctc
name: run-pre-trained-ctc
on:
push:
@@ -31,12 +31,12 @@ on:
default: 'y'
concurrency:
group: run_pre_trained_conformer_ctc-${{ github.ref }}
group: run_pre_trained_ctc-${{ github.ref }}
cancel-in-progress: true
jobs:
run_pre_trained_conformer_ctc:
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y'
run_pre_trained_ctc:
if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' || github.event.label.name == 'ctc'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -84,4 +84,4 @@ jobs:
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-pre-trained-conformer-ctc.sh
.github/scripts/run-pre-trained-ctc.sh

View File

@@ -145,7 +145,7 @@ def decode(
decoder.decode(decodable)
if not decoder.reached_final():
print(f"failed to decode {filename}")
logging.info(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
@@ -157,7 +157,7 @@ def decode(
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
logging.info(f"failed to get linear symbol sequence for {filename}")
return [""]
# tokens are incremented during graph construction

View File

@@ -132,8 +132,8 @@ def decode(
contains output from log_softmax.
HL:
The HL graph.
word2token:
A map mapping token ID to word string.
id2word:
A map mapping word ID to word string.
Returns:
Return a list of decoded words.
"""
@@ -145,7 +145,7 @@ def decode(
decoder.decode(decodable)
if not decoder.reached_final():
print(f"failed to decode {filename}")
logging.info(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
@@ -157,7 +157,7 @@ def decode(
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
logging.info(f"failed to get linear symbol sequence for {filename}")
return [""]
# are shifted by 1 during graph construction

View File

@@ -131,8 +131,8 @@ def decode(
contains output from log_softmax.
HLG:
The HLG graph.
word2token:
A map mapping token ID to word string.
id2word:
A map mapping word ID to word string.
Returns:
Return a list of decoded words.
"""
@@ -144,7 +144,7 @@ def decode(
decoder.decode(decodable)
if not decoder.reached_final():
print(f"failed to decode {filename}")
logging.info(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
@@ -156,7 +156,7 @@ def decode(
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
logging.info(f"failed to get linear symbol sequence for {filename}")
return [""]
# are shifted by 1 during graph construction

View File

@@ -0,0 +1,436 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
"""
This script exports a CTC model from PyTorch to ONNX.
Note that the model is trained using both transducer and CTC loss. This script
exports only the CTC head.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx-ctc.py \
--use-transducer 0 \
--use-ctc 1 \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size 16 \
--left-context-frames 128
It will generate the following 2 files inside $repo/exp:
- model.onnx
- model.int8.onnx
See ./onnx_pretrained_ctc.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
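
These key-value pairs can be read back at inference time. A minimal sketch (assuming onnxruntime is installed; it mirrors the get_modelmeta() call used by the decoding scripts below):

    import onnxruntime as ort

    session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    # prints e.g. {'model_type': 'zipformer2_ctc', 'version': '1', ...}
    print(session.get_modelmeta().custom_metadata_map)
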
class OnnxModel(nn.Module):
"""A wrapper for encoder_embed, Zipformer, and ctc_output layer"""
def __init__(
self,
encoder: Zipformer2,
encoder_embed: nn.Module,
ctc_output: nn.Module,
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_embed:
The first downsampling layer for zipformer.
ctc_output:
The CTC output layer, which produces log-probabilities over tokens.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.ctc_output = ctc_output
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- log_probs, a 3-D tensor of shape (N, T', vocab_size)
- log_probs_len, a 1-D int64 tensor of shape (N,)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2)
encoder_out, log_probs_len = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)
log_probs = self.ctc_output(encoder_out)
return log_probs, log_probs_len
def export_ctc_model_onnx(
model: OnnxModel,
filename: str,
opset_version: int = 11,
) -> None:
"""Export the given model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- log_probs, a tensor of shape (N, T', vocab_size)
- log_probs_len, a tensor of shape (N,)
Args:
model:
The input model
filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
model = torch.jit.trace(model, (x, x_lens))
torch.onnx.export(
model,
(x, x_lens),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["log_probs", "log_probs_len"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"log_probs": {0: "N", 1: "T"},
"log_probs_len": {0: "N"},
},
)
meta_data = {
"model_type": "zipformer2_ctc",
"version": "1",
"model_author": "k2-fsa",
"comment": "non-streaming zipformer2 CTC",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=filename, meta_data=meta_data)
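
The exported file can be sanity-checked without icefall. A minimal sketch (assuming onnxruntime is installed and model.onnx is the file produced above; the input and output names match those passed to torch.onnx.export):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    x = np.zeros((1, 100, 80), dtype=np.float32)  # (N, T, num_mel_bins)
    x_lens = np.array([100], dtype=np.int64)
    log_probs, log_probs_len = sess.run(
        ["log_probs", "log_probs_len"], {"x": x, "x_lens": x_lens}
    )
    # log_probs has shape (N, T', vocab_size); T' < T due to subsampling
    print(log_probs.shape, log_probs_len)
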
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
model = OnnxModel(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
ctc_output=model.ctc_output,
)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"num parameters: {num_param}")
opset_version = 13
logging.info("Exporting ctc model")
filename = params.exp_dir / f"model.onnx"
export_ctc_model_onnx(
model,
filename,
opset_version=opset_version,
)
logging.info(f"Exported to {filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
filename_int8 = params.exp_dir / f"model.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
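
Dynamic quantization rewrites only the weights of MatMul nodes to int8, so the model's inputs and outputs are unchanged and the int8 file is a drop-in replacement for the float32 one. A quick way to confirm the size reduction (a sketch, using the paths generated above):

    import os

    for f in ["model.onnx", "model.int8.onnx"]:
        print(f, os.path.getsize(f) / 1024 / 1024, "MB")
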
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,213 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
"""
This script loads ONNX models and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
as an example to show how to use this file.
1. Please follow ./export-onnx-ctc.py to get the onnx model.
2. Run this file
./zipformer/onnx_pretrained_ctc.py \
--nn-model /path/to/model.onnx \
--tokens /path/to/data/lang_bpe_500/tokens.txt \
1089-134686-0001.wav \
1221-135766-0001.wav \
1221-135766-0002.wav
"""
import argparse
import logging
import math
from typing import List, Tuple
import k2
import kaldifeat
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="Path to the onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
return parser
class OnnxModel:
def __init__(
self,
nn_model: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_model(nn_model)
def init_model(self, nn_model: str):
self.model = ort.InferenceSession(
nn_model,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
meta = self.model.get_modelmeta().custom_metadata_map
print(meta)
def __call__(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D float tensor of shape (N, T, C)
x_lens:
A 1-D int64 tensor of shape (N,)
Returns:
Return a tuple containing:
- A float tensor containing log_probs of shape (N, T, C)
- An int64 tensor containing log_probs_len of shape (N,)
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
self.model.get_outputs()[1].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
nn_model=args.nn_model,
)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
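# pad shorter utterances in log-mel space; log(1e-10) is effectively
# the log of zero energy, so padded frames look like silence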
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
log_probs, log_probs_len = model(features, feature_lengths)
token_table = k2.SymbolTable.from_file(args.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
blank_id = 0
s = "\n"
for i in range(log_probs.size(0)):
# greedy search
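# take the argmax token at every frame, merge consecutive duplicates,
# then drop blanks (the standard CTC best-path decoding rule)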
indexes = log_probs[i, : log_probs_len[i]].argmax(dim=-1)
token_ids = torch.unique_consecutive(indexes)
token_ids = token_ids[token_ids != blank_id]
words = token_ids_to_words(token_ids.tolist())
s += f"{args.sound_files[i]}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,277 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
"""
This script loads ONNX models and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
as an example to show how to use this file.
1. Please follow ./export-onnx-ctc.py to get the onnx model.
2. Run this file
./zipformer/onnx_pretrained_ctc_H.py \
--nn-model /path/to/model.onnx \
--tokens /path/to/data/lang_bpe_500/tokens.txt \
--H /path/to/H.fst \
1089-134686-0001.wav \
1221-135766-0001.wav \
1221-135766-0002.wav
You can find exported ONNX models at
https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
"""
import argparse
import logging
import math
from typing import Dict, List, Tuple
import k2
import kaldifeat
import kaldifst
import onnxruntime as ort
import torch
import torchaudio
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="Path to the onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--H",
type=str,
help="""Path to H.fst.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
return parser
class OnnxModel:
def __init__(
self,
nn_model: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_model(nn_model)
def init_model(self, nn_model: str):
self.model = ort.InferenceSession(
nn_model,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
meta = self.model.get_modelmeta().custom_metadata_map
print(meta)
def __call__(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D float tensor of shape (N, T, C)
x_lens:
A 1-D int64 tensor of shape (N,)
Returns:
Return a tuple containing:
- A float tensor containing log_probs of shape (N, T, C)
- An int64 tensor containing log_probs_len of shape (N,)
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
self.model.get_outputs()[1].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def decode(
filename: str,
log_probs: torch.Tensor,
H: kaldifst.StdVectorFst,
id2token: Dict[int, str],
) -> List[str]:
"""
Args:
filename:
Path to the filename for decoding. Used for debugging.
log_probs:
A 2-D float32 tensor of shape (num_frames, vocab_size). It
contains output from log_softmax.
H:
The H graph.
id2token:
A map mapping token ID to token string.
Returns:
Return a list of decoded words.
"""
logging.info(f"{filename}, {log_probs.shape}")
decodable = DecodableCtc(log_probs.cpu())
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(H, decoder_opts)
decoder.decode(decodable)
if not decoder.reached_final():
logging.info(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
(
ok,
isymbols_out,
osymbols_out,
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
logging.info(f"failed to get linear symbol sequence for {filename}")
return [""]
# output token IDs are shifted up by 1 during graph construction
# (0 is reserved for epsilon in OpenFst), so shift them back here;
# output symbol 1 corresponds to the blank token and is dropped
hyps = [id2token[i - 1] for i in osymbols_out if i != 1]
hyps = "".join(hyps).split("\u2581")  # split at ▁ (U+2581) to recover words
return hyps
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
nn_model=args.nn_model,
)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
logging.info(f"Loading H from {args.H}")
H = kaldifst.StdVectorFst.read(args.H)
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
log_probs, log_probs_len = model(features, feature_lengths)
token_table = k2.SymbolTable.from_file(args.tokens)
hyps = []
for i in range(log_probs.shape[0]):
hyp = decode(
filename=args.sound_files[i],
log_probs=log_probs[i, : log_probs_len[i]],
H=H,
id2token=token_table,
)
hyps.append(hyp)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,275 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
"""
This script loads ONNX models and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
as an example to show how to use this file.
1. Please follow ./export-onnx-ctc.py to get the onnx model.
2. Run this file
./zipformer/onnx_pretrained_ctc_HL.py \
--nn-model /path/to/model.onnx \
--words /path/to/data/lang_bpe_500/words.txt \
--HL /path/to/HL.fst \
1089-134686-0001.wav \
1221-135766-0001.wav \
1221-135766-0002.wav
You can find exported ONNX models at
https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
"""
import argparse
import logging
import math
from typing import Dict, List, Tuple
import k2
import kaldifeat
import kaldifst
import onnxruntime as ort
import torch
import torchaudio
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="Path to the onnx model. ",
)
parser.add_argument(
"--words",
type=str,
help="""Path to words.txt.""",
)
parser.add_argument(
"--HL",
type=str,
help="""Path to HL.fst.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
return parser
class OnnxModel:
def __init__(
self,
nn_model: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_model(nn_model)
def init_model(self, nn_model: str):
self.model = ort.InferenceSession(
nn_model,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
meta = self.model.get_modelmeta().custom_metadata_map
print(meta)
def __call__(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D float tensor of shape (N, T, C)
x_lens:
A 1-D int64 tensor of shape (N,)
Returns:
Return a tuple containing:
- A float tensor containing log_probs of shape (N, T, C)
- An int64 tensor containing log_probs_len of shape (N,)
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
self.model.get_outputs()[1].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def decode(
filename: str,
log_probs: torch.Tensor,
HL: kaldifst.StdVectorFst,
id2word: Dict[int, str],
) -> List[str]:
"""
Args:
filename:
Path to the filename for decoding. Used for debugging.
log_probs:
A 2-D float32 tensor of shape (num_frames, vocab_size). It
contains output from log_softmax.
HL:
The HL graph.
id2word:
A map mapping word ID to word string.
Returns:
Return a list of decoded words.
"""
logging.info(f"{filename}, {log_probs.shape}")
decodable = DecodableCtc(log_probs.cpu())
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(HL, decoder_opts)
decoder.decode(decodable)
if not decoder.reached_final():
logging.info(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
(
ok,
isymbols_out,
osymbols_out,
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
logging.info(f"failed to get linear symbol sequence for {filename}")
return [""]
# the output side of HL carries word IDs from words.txt, so they can be
# mapped to words directly (no shift is involved on the word side)
hyps = [id2word[i] for i in osymbols_out]
return hyps
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
nn_model=args.nn_model,
)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
logging.info(f"Loading HL from {args.HL}")
HL = kaldifst.StdVectorFst.read(args.HL)
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
log_probs, log_probs_len = model(features, feature_lengths)
word_table = k2.SymbolTable.from_file(args.words)
hyps = []
for i in range(log_probs.shape[0]):
hyp = decode(
filename=args.sound_files[i],
log_probs=log_probs[i, : log_probs_len[i]],
HL=HL,
id2word=word_table,
)
hyps.append(hyp)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,275 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
"""
This script loads ONNX models and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
as an example to show how to use this file.
1. Please follow ./export-onnx-ctc.py to get the onnx model.
2. Run this file
./zipformer/onnx_pretrained_ctc_HLG.py \
--nn-model /path/to/model.onnx \
--words /path/to/data/lang_bpe_500/words.txt \
--HLG /path/to/HLG.fst \
1089-134686-0001.wav \
1221-135766-0001.wav \
1221-135766-0002.wav
You can find exported ONNX models at
https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
"""
import argparse
import logging
import math
from typing import Dict, List, Tuple
import k2
import kaldifeat
import kaldifst
import onnxruntime as ort
import torch
import torchaudio
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="Path to the onnx model. ",
)
parser.add_argument(
"--words",
type=str,
help="""Path to words.txt.""",
)
parser.add_argument(
"--HLG",
type=str,
help="""Path to HLG.fst.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
return parser
class OnnxModel:
def __init__(
self,
nn_model: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_model(nn_model)
def init_model(self, nn_model: str):
self.model = ort.InferenceSession(
nn_model,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
meta = self.model.get_modelmeta().custom_metadata_map
print(meta)
def __call__(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D float tensor of shape (N, T, C)
x_lens:
A 1-D int64 tensor of shape (N,)
Returns:
Return a tuple containing:
- A float tensor containing log_probs of shape (N, T, C)
- An int64 tensor containing log_probs_len of shape (N,)
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
self.model.get_outputs()[1].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def decode(
filename: str,
log_probs: torch.Tensor,
HLG: kaldifst.StdVectorFst,
id2word: Dict[int, str],
) -> List[str]:
"""
Args:
filename:
Path to the filename for decoding. Used for debugging.
log_probs:
A 2-D float32 tensor of shape (num_frames, vocab_size). It
contains output from log_softmax.
HLG:
The HLG graph.
id2word:
A map mapping word ID to word string.
Returns:
Return a list of decoded words.
"""
logging.info(f"{filename}, {log_probs.shape}")
decodable = DecodableCtc(log_probs.cpu())
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(HLG, decoder_opts)
decoder.decode(decodable)
if not decoder.reached_final():
logging.info(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
(
ok,
isymbols_out,
osymbols_out,
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
logging.info(f"failed to get linear symbol sequence for {filename}")
return [""]
# the output side of HLG carries word IDs from words.txt, so they can be
# mapped to words directly (no shift is involved on the word side)
hyps = [id2word[i] for i in osymbols_out]
return hyps
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
nn_model=args.nn_model,
)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
logging.info(f"Loading HLG from {args.HLG}")
HLG = kaldifst.StdVectorFst.read(args.HLG)
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
log_probs, log_probs_len = model(features, feature_lengths)
word_table = k2.SymbolTable.from_file(args.words)
hyps = []
for i in range(log_probs.shape[0]):
hyp = decode(
filename=args.sound_files[i],
log_probs=log_probs[i, : log_probs_len[i]],
HLG=HLG,
id2word=word_table,
)
hyps.append(hyp)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()