Replace some with soft links
This commit is contained in:
parent 30a4dd2f95
commit 2ad176321a
@@ -1,436 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)

"""
This script exports a CTC model from PyTorch to ONNX.

Note that the model is trained using both transducer and CTC loss. This script
exports only the CTC head.

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model to ONNX

./zipformer/export-onnx-ctc.py \
  --use-transducer 0 \
  --use-ctc 1 \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --num-encoder-layers "2,2,3,4,3,2" \
  --downsampling-factor "1,2,4,8,4,2" \
  --feedforward-dim "512,768,1024,1536,1024,768" \
  --num-heads "4,4,4,8,4,4" \
  --encoder-dim "192,256,384,512,384,256" \
  --query-head-dim 32 \
  --value-head-dim 12 \
  --pos-head-dim 4 \
  --pos-dim 48 \
  --encoder-unmasked-dim "192,192,256,256,256,192" \
  --cnn-module-kernel "31,31,15,15,15,31" \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --causal False \
  --chunk-size 16 \
  --left-context-frames 128

It will generate the following 2 files inside $repo/exp:

  - model.onnx
  - model.int8.onnx

See ./onnx_pretrained_ctc.py for how to
use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple

import k2
import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc., are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/lang_bpe_500/tokens.txt",
        help="Path to the tokens.txt",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means trigram",
    )

    add_model_arguments(parser)

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


class OnnxModel(nn.Module):
    """A wrapper for encoder_embed, Zipformer, and the ctc_output layer"""

    def __init__(
        self,
        encoder: Zipformer2,
        encoder_embed: nn.Module,
        ctc_output: nn.Module,
    ):
        """
        Args:
          encoder:
            A Zipformer encoder.
          encoder_embed:
            The first downsampling layer for zipformer.
        """
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.ctc_output = ctc_output

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Please see the help information of Zipformer.forward

        Args:
          x:
            A 3-D tensor of shape (N, T, C)
          x_lens:
            A 1-D tensor of shape (N,). Its dtype is torch.int64
        Returns:
          Return a tuple containing:
            - log_probs, a 3-D tensor of shape (N, T', vocab_size)
            - log_probs_len, a 1-D int64 tensor of shape (N,)
        """
        x, x_lens = self.encoder_embed(x, x_lens)
        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)
        encoder_out, log_probs_len = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)
        log_probs = self.ctc_output(encoder_out)

        return log_probs, log_probs_len


def export_ctc_model_onnx(
    model: OnnxModel,
    filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given model to ONNX format.
    The exported model has two inputs:

      - x, a tensor of shape (N, T, C); dtype is torch.float32
      - x_lens, a tensor of shape (N,); dtype is torch.int64

    and it has two outputs:

      - log_probs, a tensor of shape (N, T', vocab_size)
      - log_probs_len, a tensor of shape (N,)

    Args:
      model:
        The input model
      filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    x = torch.zeros(1, 100, 80, dtype=torch.float32)
    x_lens = torch.tensor([100], dtype=torch.int64)

    model = torch.jit.trace(model, (x, x_lens))

    torch.onnx.export(
        model,
        (x, x_lens),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["log_probs", "log_probs_len"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "log_probs": {0: "N", 1: "T"},
            "log_probs_len": {0: "N"},
        },
    )

    meta_data = {
        "model_type": "zipformer2_ctc",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "non-streaming zipformer2 CTC",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=filename, meta_data=meta_data)


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    token_table = k2.SymbolTable.from_file(params.tokens)
    params.blank_id = token_table["<blk>"]
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    model.to(device)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(
                average_checkpoints(filenames, device=device), strict=False
            )
        elif params.avg == 1:
            load_checkpoint(
                f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
            )
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(
                average_checkpoints(filenames, device=device), strict=False
            )
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                ),
                strict=False,
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                ),
                strict=False,
            )

    model.to("cpu")
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)

    model = OnnxModel(
        encoder=model.encoder,
        encoder_embed=model.encoder_embed,
        ctc_output=model.ctc_output,
    )

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"num parameters: {num_param}")

    opset_version = 13

    logging.info("Exporting ctc model")
    filename = params.exp_dir / "model.onnx"
    export_ctc_model_onnx(
        model,
        filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported to {filename}")

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    filename_int8 = params.exp_dir / "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
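The exported model.onnx can be sanity-checked with plain onnxruntime before wiring it into ./onnx_pretrained_ctc.py. Below is a minimal sketch, assuming the model.onnx from step 2 above and an icefall-style tokens.txt (one "symbol id" pair per line); the greedy CTC decoding here (frame-wise argmax, collapse repeats, drop blanks) is a generic illustration, not the recipe's own decoding code. The input and output names match export_ctc_model_onnx() above.

# Minimal sketch: run the exported CTC model with onnxruntime and decode
# greedily. Assumes model.onnx from step 2 and a tokens.txt with
# "<symbol> <id>" per line (icefall convention).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
print(session.get_modelmeta().custom_metadata_map)  # written by add_meta_data()

# Dummy 80-dim fbank input of 100 frames: (N, T, C) = (1, 100, 80)
x = np.zeros((1, 100, 80), dtype=np.float32)
x_lens = np.array([100], dtype=np.int64)

log_probs, log_probs_len = session.run(
    ["log_probs", "log_probs_len"], {"x": x, "x_lens": x_lens}
)

id2token = {}
with open("tokens.txt", encoding="utf-8") as f:
    for line in f:
        sym, idx = line.split()
        id2token[int(idx)] = sym
blank_id = next(i for i, s in id2token.items() if s == "<blk>")

# Greedy CTC decoding: argmax per frame, collapse repeats, remove blanks.
indexes = log_probs[0, : int(log_probs_len[0])].argmax(axis=-1)
hyp = []
prev = -1
for i in indexes.tolist():
    if i != blank_id and i != prev:
        hyp.append(id2token[i])
    prev = i
print("".join(hyp).replace("▁", " ").strip())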
1 egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py (Symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-ctc.py
@@ -1,775 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
This script exports a transducer model from PyTorch to ONNX.

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model to ONNX

./zipformer/export-onnx-streaming.py \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --num-encoder-layers "2,2,3,4,3,2" \
  --downsampling-factor "1,2,4,8,4,2" \
  --feedforward-dim "512,768,1024,1536,1024,768" \
  --num-heads "4,4,4,8,4,4" \
  --encoder-dim "192,256,384,512,384,256" \
  --query-head-dim 32 \
  --value-head-dim 12 \
  --pos-head-dim 4 \
  --pos-dim 48 \
  --encoder-unmasked-dim "192,192,256,256,256,192" \
  --cnn-module-kernel "31,31,15,15,15,31" \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --causal True \
  --chunk-size 16 \
  --left-context-frames 64

The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to
`--left-context-frames`, whose value is "64,128,256,-1".

It will generate the following 3 files inside $repo/exp:

  - encoder-epoch-99-avg-1-chunk-16-left-64.onnx
  - decoder-epoch-99-avg-1-chunk-16-left-64.onnx
  - joiner-epoch-99-avg-1-chunk-16-left-64.onnx

See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple

import k2
import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import num_tokens, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 0.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc., are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/lang_bpe_500/tokens.txt",
        help="Path to the tokens.txt",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means trigram",
    )

    add_model_arguments(parser)

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


class OnnxEncoder(nn.Module):
    """A wrapper for Zipformer and the encoder_proj from the joiner"""

    def __init__(
        self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
    ):
        """
        Args:
          encoder:
            A Zipformer encoder.
          encoder_proj:
            The projection layer for encoder from the joiner.
        """
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.encoder_proj = encoder_proj
        self.chunk_size = encoder.chunk_size[0]
        self.left_context_len = encoder.left_context_frames[0]
        self.pad_length = 7 + 2 * 3

    def forward(
        self,
        x: torch.Tensor,
        states: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        N = x.size(0)
        T = self.chunk_size * 2 + self.pad_length
        x_lens = torch.tensor([T] * N, device=x.device)
        left_context_len = self.left_context_len

        cached_embed_left_pad = states[-2]
        x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
            x=x,
            x_lens=x_lens,
            cached_left_pad=cached_embed_left_pad,
        )
        assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)

        src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)

        # processed_mask is used to mask out initial states
        processed_mask = torch.arange(left_context_len, device=x.device).expand(
            x.size(0), left_context_len
        )
        processed_lens = states[-1]  # (batch,)
        # (batch, left_context_size)
        processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
        # Update processed lengths
        new_processed_lens = processed_lens + x_lens
        # (batch, left_context_size + chunk_size)
        src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

        x = x.permute(1, 0, 2)
        encoder_states = states[:-2]
        logging.info(f"len_encoder_states={len(encoder_states)}")
        (
            encoder_out,
            encoder_out_lens,
            new_encoder_states,
        ) = self.encoder.streaming_forward(
            x=x,
            x_lens=x_lens,
            states=encoder_states,
            src_key_padding_mask=src_key_padding_mask,
        )
        encoder_out = encoder_out.permute(1, 0, 2)
        encoder_out = self.encoder_proj(encoder_out)
        # Now encoder_out is of shape (N, T, joiner_dim)

        new_states = new_encoder_states + [
            new_cached_embed_left_pad,
            new_processed_lens,
        ]

        return encoder_out, new_states

    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> List[torch.Tensor]:
        """
        Returns a list of cached tensors of all encoder layers. For layer-i,
        states[i*6:(i+1)*6] is
        (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
        states[-2] is the cached left padding for the ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs).
        states[-1] is processed_lens of shape (batch,), which records the number
        of processed frames (at 50 Hz frame rate, after encoder_embed) for each
        sample in the batch.
        """
        states = self.encoder.get_init_states(batch_size, device)

        embed_states = self.encoder_embed.get_init_states(batch_size, device)

        states.append(embed_states)

        processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
        states.append(processed_lens)

        return states


class OnnxDecoder(nn.Module):
    """A wrapper for Decoder and the decoder_proj from the joiner"""

    def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
        super().__init__()
        self.decoder = decoder
        self.decoder_proj = decoder_proj

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, context_size).
        Returns:
          Return a 2-D tensor of shape (N, joiner_dim)
        """
        need_pad = False
        decoder_output = self.decoder(y, need_pad=need_pad)
        decoder_output = decoder_output.squeeze(1)
        output = self.decoder_proj(decoder_output)

        return output


class OnnxJoiner(nn.Module):
    """A wrapper for the joiner"""

    def __init__(self, output_linear: nn.Linear):
        super().__init__()
        self.output_linear = output_linear

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            A 2-D tensor of shape (N, joiner_dim)
          decoder_out:
            A 2-D tensor of shape (N, joiner_dim)
        Returns:
          Return a 2-D tensor of shape (N, vocab_size)
        """
        logit = encoder_out + decoder_out
        logit = self.output_linear(torch.tanh(logit))
        return logit


def export_encoder_model_onnx(
    encoder_model: OnnxEncoder,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    encoder_model.encoder.__class__.forward = (
        encoder_model.encoder.__class__.streaming_forward
    )

    decode_chunk_len = encoder_model.chunk_size * 2
    # The encoder_embed subsamples features: (T - 7) // 2
    # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
    T = decode_chunk_len + encoder_model.pad_length

    x = torch.rand(1, T, 80, dtype=torch.float32)
    init_state = encoder_model.get_init_states()
    num_encoders = len(encoder_model.encoder.encoder_dim)
    logging.info(f"num_encoders: {num_encoders}")
    logging.info(f"len(init_state): {len(init_state)}")

    inputs = {}
    input_names = ["x"]

    outputs = {}
    output_names = ["encoder_out"]

    def build_inputs_outputs(tensors, i):
        assert len(tensors) == 6, len(tensors)

        # (downsample_left, batch_size, key_dim)
        name = f"cached_key_{i}"
        logging.info(f"{name}.shape: {tensors[0].shape}")
        inputs[name] = {1: "N"}
        outputs[f"new_{name}"] = {1: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (1, batch_size, downsample_left, nonlin_attn_head_dim)
        name = f"cached_nonlin_attn_{i}"
        logging.info(f"{name}.shape: {tensors[1].shape}")
        inputs[name] = {1: "N"}
        outputs[f"new_{name}"] = {1: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (downsample_left, batch_size, value_dim)
        name = f"cached_val1_{i}"
        logging.info(f"{name}.shape: {tensors[2].shape}")
        inputs[name] = {1: "N"}
        outputs[f"new_{name}"] = {1: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (downsample_left, batch_size, value_dim)
        name = f"cached_val2_{i}"
        logging.info(f"{name}.shape: {tensors[3].shape}")
        inputs[name] = {1: "N"}
        outputs[f"new_{name}"] = {1: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (batch_size, embed_dim, conv_left_pad)
        name = f"cached_conv1_{i}"
        logging.info(f"{name}.shape: {tensors[4].shape}")
        inputs[name] = {0: "N"}
        outputs[f"new_{name}"] = {0: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

        # (batch_size, embed_dim, conv_left_pad)
        name = f"cached_conv2_{i}"
        logging.info(f"{name}.shape: {tensors[5].shape}")
        inputs[name] = {0: "N"}
        outputs[f"new_{name}"] = {0: "N"}
        input_names.append(name)
        output_names.append(f"new_{name}")

    num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
    encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
    cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
    ds = encoder_model.encoder.downsampling_factor
    left_context_len = encoder_model.left_context_len
    left_context_len = [left_context_len // k for k in ds]
    left_context_len = ",".join(map(str, left_context_len))
    query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim))
    value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim))
    num_heads = ",".join(map(str, encoder_model.encoder.num_heads))

    meta_data = {
        "model_type": "zipformer2",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "streaming zipformer2",
        "decode_chunk_len": str(decode_chunk_len),  # 32
        "T": str(T),  # 32+7+2*3=45
        "num_encoder_layers": num_encoder_layers,
        "encoder_dims": encoder_dims,
        "cnn_module_kernels": cnn_module_kernels,
        "left_context_len": left_context_len,
        "query_head_dims": query_head_dims,
        "value_head_dims": value_head_dims,
        "num_heads": num_heads,
    }
    logging.info(f"meta_data: {meta_data}")

    for i in range(len(init_state[:-2]) // 6):
        build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)

    # (batch_size, channels, left_pad, freq)
    embed_states = init_state[-2]
    name = "embed_states"
    logging.info(f"{name}.shape: {embed_states.shape}")
    inputs[name] = {0: "N"}
    outputs[f"new_{name}"] = {0: "N"}
    input_names.append(name)
    output_names.append(f"new_{name}")

    # (batch_size,)
    processed_lens = init_state[-1]
    name = "processed_lens"
    logging.info(f"{name}.shape: {processed_lens.shape}")
    inputs[name] = {0: "N"}
    outputs[f"new_{name}"] = {0: "N"}
    input_names.append(name)
    output_names.append(f"new_{name}")

    logging.info(inputs)
    logging.info(outputs)
    logging.info(input_names)
    logging.info(output_names)

    torch.onnx.export(
        encoder_model,
        (x, init_state),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            "x": {0: "N"},
            "encoder_out": {0: "N"},
            **inputs,
            **outputs,
        },
    )

    add_meta_data(filename=encoder_filename, meta_data=meta_data)


def export_decoder_model_onnx(
    decoder_model: OnnxDecoder,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the decoder model to ONNX format.

    The exported model has one input:

      - y: a torch.int64 tensor of shape (N, decoder_model.context_size)

    and has one output:

      - decoder_out: a torch.float32 tensor of shape (N, joiner_dim)

    Args:
      decoder_model:
        The decoder model to be exported.
      decoder_filename:
        Filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    context_size = decoder_model.decoder.context_size
    vocab_size = decoder_model.decoder.vocab_size

    y = torch.zeros(10, context_size, dtype=torch.int64)
    decoder_model = torch.jit.script(decoder_model)
    torch.onnx.export(
        decoder_model,
        y,
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y"],
        output_names=["decoder_out"],
        dynamic_axes={
            "y": {0: "N"},
            "decoder_out": {0: "N"},
        },
    )

    meta_data = {
        "context_size": str(context_size),
        "vocab_size": str(vocab_size),
    }
    add_meta_data(filename=decoder_filename, meta_data=meta_data)


def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the joiner model to ONNX format.
    The exported joiner model has two inputs:

      - encoder_out: a tensor of shape (N, joiner_dim)
      - decoder_out: a tensor of shape (N, joiner_dim)

    and produces one output:

      - logit: a tensor of shape (N, vocab_size)
    """
    joiner_dim = joiner_model.output_linear.weight.shape[1]
    logging.info(f"joiner dim: {joiner_dim}")

    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)

    torch.onnx.export(
        joiner_model,
        (projected_encoder_out, projected_decoder_out),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "encoder_out",
            "decoder_out",
        ],
        output_names=["logit"],
        dynamic_axes={
            "encoder_out": {0: "N"},
            "decoder_out": {0: "N"},
            "logit": {0: "N"},
        },
    )
    meta_data = {
        "joiner_dim": str(joiner_dim),
    }
    add_meta_data(filename=joiner_filename, meta_data=meta_data)


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    token_table = k2.SymbolTable.from_file(params.tokens)
    params.blank_id = token_table["<blk>"]
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    model.to(device)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to("cpu")
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True)

    encoder = OnnxEncoder(
        encoder=model.encoder,
        encoder_embed=model.encoder_embed,
        encoder_proj=model.joiner.encoder_proj,
    )

    decoder = OnnxDecoder(
        decoder=model.decoder,
        decoder_proj=model.joiner.decoder_proj,
    )

    joiner = OnnxJoiner(output_linear=model.joiner.output_linear)

    encoder_num_param = sum([p.numel() for p in encoder.parameters()])
    decoder_num_param = sum([p.numel() for p in decoder.parameters()])
    joiner_num_param = sum([p.numel() for p in joiner.parameters()])
    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
    logging.info(f"encoder parameters: {encoder_num_param}")
    logging.info(f"decoder parameters: {decoder_num_param}")
    logging.info(f"joiner parameters: {joiner_num_param}")
    logging.info(f"total parameters: {total_num_param}")

    if params.iter > 0:
        suffix = f"iter-{params.iter}"
    else:
        suffix = f"epoch-{params.epoch}"

    suffix += f"-avg-{params.avg}"
    suffix += f"-chunk-{params.chunk_size}"
    suffix += f"-left-{params.left_context_frames}"

    opset_version = 13

    logging.info("Exporting encoder")
    encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
    export_encoder_model_onnx(
        encoder,
        encoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported encoder to {encoder_filename}")

    logging.info("Exporting decoder")
    decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
    export_decoder_model_onnx(
        decoder,
        decoder_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported decoder to {decoder_filename}")

    logging.info("Exporting joiner")
    joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
    export_joiner_model_onnx(
        joiner,
        joiner_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported joiner to {joiner_filename}")

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=encoder_filename,
        model_output=encoder_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=decoder_filename,
        model_output=decoder_filename_int8,
        op_types_to_quantize=["MatMul", "Gather"],
        weight_type=QuantType.QInt8,
    )

    joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=joiner_filename,
        model_output=joiner_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
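A streaming client needs decode_chunk_len, T, and the cached-state layout to drive the exported encoder; export_encoder_model_onnx() records them all in the ONNX metadata via add_meta_data(). A minimal sketch for reading them back with onnxruntime, assuming the encoder filename printed by the script above:

# Minimal sketch: inspect the streaming metadata and I/O signature of the
# exported encoder. The filename is an assumption; use whatever this script
# printed for your epoch/avg/chunk/left settings.
import onnxruntime as ort

session = ort.InferenceSession("encoder-epoch-99-avg-1-chunk-16-left-64.onnx")
meta = session.get_modelmeta().custom_metadata_map

# e.g. decode_chunk_len=32 and T=45: feed (N, 45, 80) feature frames per call
# and advance the feature window by 32 frames between calls.
print("decode_chunk_len:", meta["decode_chunk_len"])
print("T (frames per call):", meta["T"])
print("num_encoder_layers:", meta["num_encoder_layers"])
print("left_context_len (per stack):", meta["left_context_len"])

# The input list mirrors build_inputs_outputs(): x, then six cached_* states
# per layer, then embed_states and processed_lens.
for inp in session.get_inputs():
    print(inp.name, inp.shape)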
1 egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py (Symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -1,280 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`,
and uses them to decode waves.
You can use the following command to get the exported models:

./zipformer/export.py \
  --exp-dir ./zipformer/exp \
  --tokens data/lang_bpe_500/tokens.txt \
  --epoch 30 \
  --avg 9 \
  --jit 1

Usage of this script:

./zipformer/jit_pretrained.py \
  --nn-model-filename ./zipformer/exp/cpu_jit.pt \
  --tokens ./data/lang_bpe_500/tokens.txt \
  /path/to/foo.wav \
  /path/to/bar.wav
"""

import argparse
import logging
import math
from typing import List

import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model cpu_jit.pt",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


def greedy_search(
    model: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      model:
        The transducer model.
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,).
    Returns:
      Return the decoded results for each utterance.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = model.decoder.blank_id

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    context_size = model.decoder.context_size
    hyps = [[blank_id] * context_size for _ in range(N)]

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = model.decoder(
        decoder_input,
        need_pad=torch.tensor([False]),
    ).squeeze(1)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = packed_encoder_out.data[start:end]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        offset = end

        decoder_out = decoder_out[:batch_size]

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
        )
        # logits' shape: (batch_size, vocab_size)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.decoder(
                decoder_input,
                need_pad=torch.tensor([False]),
            )
            decoder_out = decoder_out.squeeze(1)

    sorted_ans = [h[context_size:] for h in hyps]
    ans = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])

    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = torch.jit.load(args.nn_model_filename)

    model.eval()

    model.to(device)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = model.encoder(
        features=features,
        feature_lengths=feature_lengths,
    )

    hyps = greedy_search(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
    )

    s = "\n"

    token_table = k2.SymbolTable.from_file(args.tokens)

    def token_ids_to_words(token_ids: List[int]) -> str:
        text = ""
        for i in token_ids:
            text += token_table[i]
        return text.replace("▁", " ").strip()

    for filename, hyp in zip(args.sound_files, hyps):
        words = token_ids_to_words(hyp)
        s += f"{filename}:\n{words}\n"

    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
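greedy_search() above leans on a property of pack_padded_sequence: the packed data stores, for each time step, one frame from every utterance that is still alive at that step, sorted longest first, so batch_sizes gives the effective batch size per frame. A small self-contained illustration (toy tensors, not recipe code):

# Illustrates why greedy_search() can iterate over
# packed_encoder_out.batch_sizes with a shrinking batch.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

encoder_out = torch.arange(12, dtype=torch.float32).reshape(3, 4, 1)  # (N, T, C)
encoder_out_lens = torch.tensor([4, 2, 3])

packed = pack_padded_sequence(
    input=encoder_out,
    lengths=encoder_out_lens.cpu(),
    batch_first=True,
    enforce_sorted=False,
)
# tensor([3, 3, 2, 1]): 3 utterances alive at t=0 and t=1, 2 at t=2, 1 at t=3.
print(packed.batch_sizes)
# Maps positions in the length-sorted batch back to the original order; this
# is what the final loop over unsorted_indices in greedy_search() undoes.
print(packed.unsorted_indices)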
1 egs/gigaspeech/ASR/zipformer/jit_pretrained.py (Symbolic link)
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained.py
@@ -1,436 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--causal 1 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) ctc-decoding
|
||||
./zipformer/jit_pretrained_ctc.py \
|
||||
--model-filename ./zipformer/exp/jit_script.pt \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--method ctc-decoding \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) 1best
|
||||
./zipformer/jit_pretrained_ctc.py \
|
||||
--model-filename ./zipformer/exp/jit_script.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--method 1best \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) nbest-rescoring
|
||||
./zipformer/jit_pretrained_ctc.py \
|
||||
--model-filename ./zipformer/exp/jit_script.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method nbest-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) whole-lattice-rescoring
|
||||
./zipformer/jit_pretrained_ctc.py \
|
||||
--model-filename ./zipformer/exp/jit_script.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method whole-lattice-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from ctc_decode import get_decoding_params
|
||||
from export import num_tokens
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import get_params
|
||||
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
type=str,
|
||||
help="""Path to words.txt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG",
|
||||
type=str,
|
||||
help="""Path to HLG.pt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.
|
||||
Used only when method is ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Possible values are:
|
||||
(0) ctc-decoding - Use CTC decoding. It uses a token table,
|
||||
i.e., lang_dir/token.txt, to convert
|
||||
word pieces to words. It needs neither a lexicon
|
||||
nor an n-gram LM.
|
||||
(1) 1best - Use the best path as decoding output. Only
|
||||
the transformer encoder output is used for decoding.
|
||||
We call it HLG decoding.
|
||||
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an LM, the path with
|
||||
the highest score is the decoding result.
|
||||
We call it HLG decoding + nbest n-gram LM rescoring.
|
||||
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + whole-lattice n-gram LM rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--G",
|
||||
type=str,
|
||||
help="""An LM for rescoring.
|
||||
Used only when method is
|
||||
whole-lattice-rescoring or nbest-rescoring.
|
||||
It's usually a 4-gram LM.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies the size of n-best list.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=1.3,
|
||||
help="""
|
||||
Used only when method is whole-lattice-rescoring and nbest-rescoring.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="""
|
||||
Used only when method is nbest-rescoring.
|
||||
It specifies the scale for lattice.scores when
|
||||
extracting n-best lists. A smaller value results in
|
||||
more unique number of paths with the risk of missing
|
||||
the best path.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    params.update(vars(args))

    token_table = k2.SymbolTable.from_file(params.tokens)
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = torch.jit.load(args.model_filename)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = model.encoder(features, feature_lengths)
    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

    batch_size = ctc_output.shape[0]
    supervision_segments = torch.tensor(
        [
            [i, 0, feature_lengths[i].item() // params.subsampling_factor]
            for i in range(batch_size)
        ],
        dtype=torch.int32,
    )
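    # [Explanatory note, not part of the original file] Each row of
    # supervision_segments is (utterance_index, start_frame, num_frames),
    # measured in frames after subsampling. E.g. with subsampling_factor=4
    # and two waves of 1000 and 800 feature frames, the rows would be
    # [0, 0, 250] and [1, 0, 200].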

    if params.method == "ctc-decoding":
        logging.info("Use CTC decoding")
        max_token_id = params.vocab_size - 1

        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )

        lattice = get_lattice(
            nnet_output=ctc_output,
            decoding_graph=H,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        token_ids = get_texts(best_path)
        hyps = [[token_table[i] for i in ids] for ids in token_ids]
    elif params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]:
        logging.info(f"Loading HLG from {params.HLG}")
        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
        HLG = HLG.to(device)
        if not hasattr(HLG, "lm_scores"):
            # For whole-lattice-rescoring and attention-decoder
            HLG.lm_scores = HLG.scores.clone()

        if params.method in [
            "nbest-rescoring",
            "whole-lattice-rescoring",
        ]:
            logging.info(f"Loading G from {params.G}")
            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
            G = G.to(device)
            if params.method == "whole-lattice-rescoring":
                # Add epsilon self-loops to G as we will compose
                # it with the whole lattice later
                G = k2.add_epsilon_self_loops(G)
                G = k2.arc_sort(G)

            # G.lm_scores is used to replace HLG.lm_scores during
            # LM rescoring.
            G.lm_scores = G.scores.clone()

        lattice = get_lattice(
            nnet_output=ctc_output,
            decoding_graph=HLG,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        if params.method == "1best":
            logging.info("Use HLG decoding")
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
        if params.method == "nbest-rescoring":
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_n_best_list(
                lattice=lattice,
                G=G,
                num_paths=params.num_paths,
                lm_scale_list=[params.ngram_lm_scale],
                nbest_scale=params.nbest_scale,
            )
            best_path = next(iter(best_path_dict.values()))
        elif params.method == "whole-lattice-rescoring":
            logging.info("Use HLG decoding + whole-lattice LM rescoring")
            best_path_dict = rescore_with_whole_lattice(
                lattice=lattice,
                G_with_epsilon_loops=G,
                lm_scale_list=[params.ngram_lm_scale],
            )
            best_path = next(iter(best_path_dict.values()))

        hyps = get_texts(best_path)
        word_sym_table = k2.SymbolTable.from_file(params.words_file)
        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
    else:
        raise ValueError(f"Unsupported decoding method: {params.method}")

    s = "\n"
    if params.method == "ctc-decoding":
        for filename, hyp in zip(params.sound_files, hyps):
            words = "".join(hyp)
            words = words.replace("▁", " ").strip()
            s += f"{filename}:\n{words}\n\n"
    elif params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]:
        for filename, hyp in zip(params.sound_files, hyps):
            words = " ".join(hyp)
            words = words.replace("▁", " ").strip()
            s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py
@ -1,273 +0,0 @@
#!/usr/bin/env python3
# flake8: noqa
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models exported by `torch.jit.script()`
and uses them to decode waves.
You can use the following command to get the exported models:

./zipformer/export.py \
  --exp-dir ./zipformer/exp \
  --causal 1 \
  --chunk-size 16 \
  --left-context-frames 128 \
  --tokens data/lang_bpe_500/tokens.txt \
  --epoch 30 \
  --avg 9 \
  --jit 1

Usage of this script:

./zipformer/jit_pretrained_streaming.py \
  --nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \
  --tokens ./data/lang_bpe_500/tokens.txt \
  /path/to/foo.wav
"""

import argparse
import logging
import math
from typing import List, Optional

import k2
import kaldifeat
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--nn-model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model jit_script.pt",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "sound_file",
        type=str,
        help="The input sound file to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0])
    return ans


def greedy_search(
    decoder: torch.jit.ScriptModule,
    joiner: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    decoder_out: Optional[torch.Tensor] = None,
    hyp: Optional[List[int]] = None,
    device: torch.device = torch.device("cpu"),
):
    assert encoder_out.ndim == 2
    context_size = decoder.context_size
    blank_id = decoder.blank_id

    if decoder_out is None:
        assert hyp is None, hyp
        hyp = [blank_id] * context_size
        decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0)
        # decoder_input.shape is (1, context_size)
        decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
    else:
        assert decoder_out.ndim == 2
        assert hyp is not None, hyp

    T = encoder_out.size(0)
    for i in range(T):
        cur_encoder_out = encoder_out[i : i + 1]
        joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
        y = joiner_out.argmax(dim=0).item()

        if y != blank_id:
            hyp.append(y)
            decoder_input = hyp[-context_size:]

            decoder_input = torch.tensor(
                decoder_input, dtype=torch.int32, device=device
            ).unsqueeze(0)
            decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)

    return hyp, decoder_out
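
# [Explanatory note, not part of the original file] greedy_search is re-entrant
# across streaming chunks: the caller feeds the returned (hyp, decoder_out)
# pair back in for the next chunk, so the decoder state and the partial
# hypothesis survive chunk boundaries. On the very first call both are None
# and the hypothesis is primed with `context_size` blanks.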


def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
    """Create a CPU streaming feature extractor.

    At present, we assume it returns a fbank feature extractor with
    fixed options. In the future, we will support passing in the options
    from outside.

    Returns:
      Return a CPU streaming feature extractor.
    """
    opts = FbankOptions()
    opts.device = "cpu"
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = sample_rate
    opts.mel_opts.num_bins = 80
    return OnlineFbank(opts)
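
# [Illustrative sketch, assuming the kaldifeat online API used above] A caller
# pushes raw samples in and pulls frames out as they become ready, e.g.:
#
#     fbank = create_streaming_feature_extractor(16000)
#     fbank.accept_waveform(sampling_rate=16000, waveform=samples)
#     while fbank.num_frames_ready > num_done:
#         frame = fbank.get_frame(num_done)  # one 80-dim fbank frame
#         num_done += 1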


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = torch.jit.load(args.nn_model_filename)
    model.eval()
    model.to(device)

    encoder = model.encoder
    decoder = model.decoder
    joiner = model.joiner

    token_table = k2.SymbolTable.from_file(args.tokens)
    context_size = decoder.context_size

    logging.info("Constructing Fbank computer")
    online_fbank = create_streaming_feature_extractor(args.sample_rate)

    logging.info(f"Reading sound file: {args.sound_file}")
    wave_samples = read_sound_files(
        filenames=[args.sound_file],
        expected_sample_rate=args.sample_rate,
    )[0]
    logging.info(wave_samples.shape)

    logging.info("Decoding started")

    chunk_length = encoder.chunk_size * 2
    T = chunk_length + encoder.pad_length

    logging.info(f"chunk_length: {chunk_length}")
    logging.info(f"T: {T}")

    states = encoder.get_init_states(device=device)

    tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)

    wave_samples = torch.cat([wave_samples, tail_padding])

    chunk = int(0.25 * args.sample_rate)  # 0.25 second
    num_processed_frames = 0

    hyp = None
    decoder_out = None

    start = 0
    while start < wave_samples.numel():
        logging.info(f"{start}/{wave_samples.numel()}")
        end = min(start + chunk, wave_samples.numel())
        samples = wave_samples[start:end]
        start += chunk
        online_fbank.accept_waveform(
            sampling_rate=args.sample_rate,
            waveform=samples,
        )
        while online_fbank.num_frames_ready - num_processed_frames >= T:
            frames = []
            for i in range(T):
                frames.append(online_fbank.get_frame(num_processed_frames + i))
            frames = torch.cat(frames, dim=0).to(device).unsqueeze(0)
            x_lens = torch.tensor([T], dtype=torch.int32, device=device)
            encoder_out, out_lens, states = encoder(
                features=frames,
                feature_lengths=x_lens,
                states=states,
            )
            num_processed_frames += chunk_length

            hyp, decoder_out = greedy_search(
                decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device
            )

    text = ""
    for i in hyp[context_size:]:
        text += token_table[i]
    text = text.replace("▁", " ").strip()

    logging.info(args.sound_file)
    logging.info(text)

    logging.info("Decoding Done")


torch.set_num_threads(4)
torch.set_num_interop_threads(1)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py
@ -1,240 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script checks that exported onnx models produce the same output
as the given torchscript model for the same input.

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model via torchscript (torch.jit.script())

./zipformer/export.py \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp/ \
  --jit 1

It will generate the following file in $repo/exp:
  - jit_script.pt

3. Export the model to ONNX

./zipformer/export-onnx.py \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp/

It will generate the following 3 files inside $repo/exp:

  - encoder-epoch-99-avg-1.onnx
  - decoder-epoch-99-avg-1.onnx
  - joiner-epoch-99-avg-1.onnx

4. Run this file

./zipformer/onnx_check.py \
  --jit-filename $repo/exp/jit_script.pt \
  --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
  --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
  --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
"""

import argparse
import logging

import torch
from onnx_pretrained import OnnxModel

from icefall import is_module_available


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--jit-filename",
        required=True,
        type=str,
        help="Path to the torchscript model",
    )

    parser.add_argument(
        "--onnx-encoder-filename",
        required=True,
        type=str,
        help="Path to the onnx encoder model",
    )

    parser.add_argument(
        "--onnx-decoder-filename",
        required=True,
        type=str,
        help="Path to the onnx decoder model",
    )

    parser.add_argument(
        "--onnx-joiner-filename",
        required=True,
        type=str,
        help="Path to the onnx joiner model",
    )

    return parser


def test_encoder(
    torch_model: torch.jit.ScriptModule,
    onnx_model: OnnxModel,
):
    C = 80
    for i in range(3):
        N = torch.randint(low=1, high=20, size=(1,)).item()
        T = torch.randint(low=30, high=50, size=(1,)).item()
        logging.info(f"test_encoder: iter {i}, N={N}, T={T}")

        x = torch.rand(N, T, C)
        x_lens = torch.randint(low=30, high=T + 1, size=(N,))
        x_lens[0] = T

        torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)
        torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out)

        onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens)

        assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), (
            (torch_encoder_out - onnx_encoder_out).abs().max()
        )


def test_decoder(
    torch_model: torch.jit.ScriptModule,
    onnx_model: OnnxModel,
):
    context_size = onnx_model.context_size
    vocab_size = onnx_model.vocab_size
    for i in range(10):
        N = torch.randint(1, 100, size=(1,)).item()
        logging.info(f"test_decoder: iter {i}, N={N}")
        x = torch.randint(
            low=1,
            high=vocab_size,
            size=(N, context_size),
            dtype=torch.int64,
        )
        torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False]))
        torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out)
        torch_decoder_out = torch_decoder_out.squeeze(1)

        onnx_decoder_out = onnx_model.run_decoder(x)
        assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
            (torch_decoder_out - onnx_decoder_out).abs().max()
        )


def test_joiner(
    torch_model: torch.jit.ScriptModule,
    onnx_model: OnnxModel,
):
    encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1]
    decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1]
    for i in range(10):
        N = torch.randint(1, 100, size=(1,)).item()
        logging.info(f"test_joiner: iter {i}, N={N}")
        encoder_out = torch.rand(N, encoder_dim)
        decoder_out = torch.rand(N, decoder_dim)

        projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out)
        projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out)

        torch_joiner_out = torch_model.joiner(encoder_out, decoder_out)
        onnx_joiner_out = onnx_model.run_joiner(
            projected_encoder_out, projected_decoder_out
        )

        assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
            (torch_joiner_out - onnx_joiner_out).abs().max()
        )
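
# [Explanatory note, not part of the original file] All three tests follow the
# same pattern: run random inputs through both backends and compare with
# torch.allclose under a small absolute tolerance, printing the max absolute
# difference on failure. Small discrepancies are expected because ONNX Runtime
# and PyTorch may order floating-point operations differently.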


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    logging.info(vars(args))

    torch_model = torch.jit.load(args.jit_filename)

    onnx_model = OnnxModel(
        encoder_model_filename=args.onnx_encoder_filename,
        decoder_model_filename=args.onnx_decoder_filename,
        joiner_model_filename=args.onnx_joiner_filename,
    )

    logging.info("Test encoder")
    test_encoder(torch_model, onnx_model)

    logging.info("Test decoder")
    test_decoder(torch_model, onnx_model)

    logging.info("Test joiner")
    test_joiner(torch_model, onnx_model)
    logging.info("Finished checking ONNX models")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# See https://github.com/pytorch/pytorch/issues/38342
# and https://github.com/pytorch/pytorch/issues/33354
#
# If we don't do this, the delay increases whenever there is
# a new request that changes the actual batch size.
# If you use `py-spy dump --pid <server-pid> --native`, you will
# see a lot of time is spent in re-compiling the torch script model.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)

if __name__ == "__main__":
    torch.manual_seed(20220727)
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
egs/gigaspeech/ASR/zipformer/onnx_check.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_check.py
@ -1,325 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao,
#                                                 Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model to ONNX

./zipformer/export-onnx.py \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --causal False

It will generate the following 3 files inside $repo/exp:

  - encoder-epoch-99-avg-1.onnx
  - decoder-epoch-99-avg-1.onnx
  - joiner-epoch-99-avg-1.onnx

3. Run this file

./zipformer/onnx_decode.py \
  --exp-dir $repo/exp \
  --max-duration 600 \
  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
  --tokens $repo/data/lang_bpe_500/tokens.txt
"""


import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule

from onnx_pretrained import greedy_search, OnnxModel

from icefall.utils import setup_logger, store_transcripts, write_error_stats
from k2 import SymbolTable


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder onnx model. ",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder onnx model. ",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner onnx model. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="Valid values are greedy_search and modified_beam_search",
    )

    return parser


def decode_one_batch(
    model: OnnxModel, token_table: SymbolTable, batch: dict
) -> List[List[str]]:
    """Decode one batch and return the result.
    Currently only greedy_search is supported.

    Args:
      model:
        The neural model.
      token_table:
        The token table.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.

    Returns:
      Return the decoded results for each utterance.
    """
    feature = batch["inputs"]
    assert feature.ndim == 3
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(dtype=torch.int64)

    encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)

    hyps = greedy_search(
        model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
    )

    def token_ids_to_words(token_ids: List[int]) -> str:
        text = ""
        for i in token_ids:
            text += token_table[i]
        return text.replace("▁", " ").strip()

    hyps = [token_ids_to_words(h).split() for h in hyps]
    return hyps
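
# [Illustrative note, not part of the original file] token_ids_to_words above
# concatenates BPE pieces and turns the "▁" word-boundary marker back into
# spaces: e.g. ["▁HE", "LLO", "▁WORLD"] -> "▁HELLO▁WORLD" -> "HELLO WORLD".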


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    model: nn.Module,
    token_table: SymbolTable,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      model:
        The neural model.
      token_table:
        The token table.

    Returns:
      - A list of tuples. Each tuple contains three elements:
        - cut_id,
        - reference transcript,
        - predicted result.
      - The total duration (in seconds) of the dataset.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    log_interval = 10
    total_duration = 0

    results = []
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
        total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])

        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)

        this_batch = []
        assert len(hyps) == len(texts)
        for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
            ref_words = ref_text.split()
            this_batch.append((cut_id, ref_words, hyp_words))

        results.extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

    return results, total_duration


def save_results(
    res_dir: Path,
    test_set_name: str,
    results: List[Tuple[str, List[str], List[str]]],
):
    recog_path = res_dir / f"recogs-{test_set_name}.txt"
    results = sorted(results)
    store_transcripts(filename=recog_path, texts=results)
    logging.info(f"The transcripts are stored in {recog_path}")

    # The following prints out WERs, per-word error statistics and aligned
    # ref/hyp pairs.
    errs_filename = res_dir / f"errs-{test_set_name}.txt"
    with open(errs_filename, "w") as f:
        wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)

    logging.info("Wrote detailed error stats to {}".format(errs_filename))

    errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("WER", file=f)
        print(wer, file=f)

    s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    assert (
        args.decoding_method == "greedy_search"
    ), "Only supports greedy_search currently."
    res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"

    setup_logger(f"{res_dir}/log-decode")
    logging.info("Decoding started")

    device = torch.device("cpu")
    logging.info(f"Device: {device}")

    token_table = SymbolTable.from_file(args.tokens)

    logging.info(vars(args))

    logging.info("About to create model")
    model = OnnxModel(
        encoder_model_filename=args.encoder_model_filename,
        decoder_model_filename=args.decoder_model_filename,
        joiner_model_filename=args.joiner_model_filename,
    )

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    test_dls = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        start_time = time.time()
        results, total_duration = decode_dataset(
            dl=test_dl, model=model, token_table=token_table
        )
        end_time = time.time()
        elapsed_seconds = end_time - start_time
        rtf = elapsed_seconds / total_duration

        logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
        logging.info(f"Wave duration: {total_duration:.3f} s")
        logging.info(
            f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
        )
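        # [Worked example, not part of the original file] RTF is decoding time
        # divided by audio duration, so values below 1.0 are faster than real
        # time: decoding 19440 s of audio in 1944 s gives RTF = 0.100, i.e.
        # roughly 10x faster than real time.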

        save_results(res_dir=res_dir, test_set_name=test_set, results=results)

    logging.info("Done!")


if __name__ == "__main__":
    main()
egs/gigaspeech/ASR/zipformer/onnx_decode.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_decode.py
@ -1,546 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
This script loads ONNX models exported by ./export-onnx-streaming.py
and uses them to decode waves.

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model to ONNX

./zipformer/export-onnx-streaming.py \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --num-encoder-layers "2,2,3,4,3,2" \
  --downsampling-factor "1,2,4,8,4,2" \
  --feedforward-dim "512,768,1024,1536,1024,768" \
  --num-heads "4,4,4,8,4,4" \
  --encoder-dim "192,256,384,512,384,256" \
  --query-head-dim 32 \
  --value-head-dim 12 \
  --pos-head-dim 4 \
  --pos-dim 48 \
  --encoder-unmasked-dim "192,192,256,256,256,192" \
  --cnn-module-kernel "31,31,15,15,15,31" \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --causal True \
  --chunk-size 16 \
  --left-context-frames 64

It will generate the following 3 files inside $repo/exp:

  - encoder-epoch-99-avg-1.onnx
  - decoder-epoch-99-avg-1.onnx
  - joiner-epoch-99-avg-1.onnx

3. Run this file with the exported ONNX models

./zipformer/onnx_pretrained-streaming.py \
  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  $repo/test_wavs/1089-134686-0001.wav

Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
"""

import argparse
import logging
from typing import Dict, List, Optional, Tuple

import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder onnx model. ",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder onnx model. ",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner onnx model. ",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "sound_file",
        type=str,
        help="The input sound file to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser


class OnnxModel:
    def __init__(
        self,
        encoder_model_filename: str,
        decoder_model_filename: str,
        joiner_model_filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.init_encoder(encoder_model_filename)
        self.init_decoder(decoder_model_filename)
        self.init_joiner(joiner_model_filename)

    def init_encoder(self, encoder_model_filename: str):
        self.encoder = ort.InferenceSession(
            encoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        self.init_encoder_states()

    def init_encoder_states(self, batch_size: int = 1):
        encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
        logging.info(f"encoder_meta={encoder_meta}")

        model_type = encoder_meta["model_type"]
        assert model_type == "zipformer2", model_type

        decode_chunk_len = int(encoder_meta["decode_chunk_len"])
        T = int(encoder_meta["T"])

        num_encoder_layers = encoder_meta["num_encoder_layers"]
        encoder_dims = encoder_meta["encoder_dims"]
        cnn_module_kernels = encoder_meta["cnn_module_kernels"]
        left_context_len = encoder_meta["left_context_len"]
        query_head_dims = encoder_meta["query_head_dims"]
        value_head_dims = encoder_meta["value_head_dims"]
        num_heads = encoder_meta["num_heads"]

        def to_int_list(s):
            return list(map(int, s.split(",")))

        num_encoder_layers = to_int_list(num_encoder_layers)
        encoder_dims = to_int_list(encoder_dims)
        cnn_module_kernels = to_int_list(cnn_module_kernels)
        left_context_len = to_int_list(left_context_len)
        query_head_dims = to_int_list(query_head_dims)
        value_head_dims = to_int_list(value_head_dims)
        num_heads = to_int_list(num_heads)

        logging.info(f"decode_chunk_len: {decode_chunk_len}")
        logging.info(f"T: {T}")
        logging.info(f"num_encoder_layers: {num_encoder_layers}")
        logging.info(f"encoder_dims: {encoder_dims}")
        logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
        logging.info(f"left_context_len: {left_context_len}")
        logging.info(f"query_head_dims: {query_head_dims}")
        logging.info(f"value_head_dims: {value_head_dims}")
        logging.info(f"num_heads: {num_heads}")

        num_encoders = len(num_encoder_layers)

        self.states = []
        for i in range(num_encoders):
            num_layers = num_encoder_layers[i]
            key_dim = query_head_dims[i] * num_heads[i]
            embed_dim = encoder_dims[i]
            nonlin_attn_head_dim = 3 * embed_dim // 4
            value_dim = value_head_dims[i] * num_heads[i]
            conv_left_pad = cnn_module_kernels[i] // 2

            for layer in range(num_layers):
                cached_key = torch.zeros(
                    left_context_len[i], batch_size, key_dim
                ).numpy()
                cached_nonlin_attn = torch.zeros(
                    1, batch_size, left_context_len[i], nonlin_attn_head_dim
                ).numpy()
                cached_val1 = torch.zeros(
                    left_context_len[i], batch_size, value_dim
                ).numpy()
                cached_val2 = torch.zeros(
                    left_context_len[i], batch_size, value_dim
                ).numpy()
                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
                self.states += [
                    cached_key,
                    cached_nonlin_attn,
                    cached_val1,
                    cached_val2,
                    cached_conv1,
                    cached_conv2,
                ]
        embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
        self.states.append(embed_states)
        processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
        self.states.append(processed_lens)

        self.num_encoders = num_encoders

        self.segment = T
        self.offset = decode_chunk_len
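
    # [Explanatory note, not part of the original file] init_encoder_states
    # allocates six cached tensors per zipformer layer (attention key,
    # nonlin-attention cache, two value caches, two convolution caches), plus
    # one embed_states tensor and one processed_lens counter for the whole
    # model, i.e. len(self.states) == 6 * sum(num_encoder_layers) + 2. The
    # same flat ordering is unpacked by _build_encoder_input_output below.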

    def init_decoder(self, decoder_model_filename: str):
        self.decoder = ort.InferenceSession(
            decoder_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
        self.context_size = int(decoder_meta["context_size"])
        self.vocab_size = int(decoder_meta["vocab_size"])

        logging.info(f"context_size: {self.context_size}")
        logging.info(f"vocab_size: {self.vocab_size}")

    def init_joiner(self, joiner_model_filename: str):
        self.joiner = ort.InferenceSession(
            joiner_model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
        self.joiner_dim = int(joiner_meta["joiner_dim"])

        logging.info(f"joiner_dim: {self.joiner_dim}")

    def _build_encoder_input_output(
        self,
        x: torch.Tensor,
    ) -> Tuple[Dict[str, np.ndarray], List[str]]:
        encoder_input = {"x": x.numpy()}
        encoder_output = ["encoder_out"]

        def build_inputs_outputs(tensors, i):
            assert len(tensors) == 6, len(tensors)

            # (downsample_left, batch_size, key_dim)
            name = f"cached_key_{i}"
            encoder_input[name] = tensors[0]
            encoder_output.append(f"new_{name}")

            # (1, batch_size, downsample_left, nonlin_attn_head_dim)
            name = f"cached_nonlin_attn_{i}"
            encoder_input[name] = tensors[1]
            encoder_output.append(f"new_{name}")

            # (downsample_left, batch_size, value_dim)
            name = f"cached_val1_{i}"
            encoder_input[name] = tensors[2]
            encoder_output.append(f"new_{name}")

            # (downsample_left, batch_size, value_dim)
            name = f"cached_val2_{i}"
            encoder_input[name] = tensors[3]
            encoder_output.append(f"new_{name}")

            # (batch_size, embed_dim, conv_left_pad)
            name = f"cached_conv1_{i}"
            encoder_input[name] = tensors[4]
            encoder_output.append(f"new_{name}")

            # (batch_size, embed_dim, conv_left_pad)
            name = f"cached_conv2_{i}"
            encoder_input[name] = tensors[5]
            encoder_output.append(f"new_{name}")

        for i in range(len(self.states[:-2]) // 6):
            build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)

        # (batch_size, channels, left_pad, freq)
        name = "embed_states"
        embed_states = self.states[-2]
        encoder_input[name] = embed_states
        encoder_output.append(f"new_{name}")

        # (batch_size,)
        name = "processed_lens"
        processed_lens = self.states[-1]
        encoder_input[name] = processed_lens
        encoder_output.append(f"new_{name}")

        return encoder_input, encoder_output

    def _update_states(self, states: List[np.ndarray]):
        self.states = states

    def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C)
        Returns:
          Return a 3-D tensor of shape (N, T', joiner_dim) where
          T' is usually equal to ((T-7)//2+1)//2
        """
        encoder_input, encoder_output_names = self._build_encoder_input_output(x)

        out = self.encoder.run(encoder_output_names, encoder_input)

        self._update_states(out[1:])

        return torch.from_numpy(out[0])

    def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
        """
        Args:
          decoder_input:
            A 2-D tensor of shape (N, context_size)
        Returns:
          Return a 2-D tensor of shape (N, joiner_dim)
        """
        out = self.decoder.run(
            [self.decoder.get_outputs()[0].name],
            {self.decoder.get_inputs()[0].name: decoder_input.numpy()},
        )[0]

        return torch.from_numpy(out)

    def run_joiner(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            A 2-D tensor of shape (N, joiner_dim)
          decoder_out:
            A 2-D tensor of shape (N, joiner_dim)
        Returns:
          Return a 2-D tensor of shape (N, vocab_size)
        """
        out = self.joiner.run(
            [self.joiner.get_outputs()[0].name],
            {
                self.joiner.get_inputs()[0].name: encoder_out.numpy(),
                self.joiner.get_inputs()[1].name: decoder_out.numpy(),
            },
        )[0]

        return torch.from_numpy(out)


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


def create_streaming_feature_extractor() -> OnlineFeature:
    """Create a CPU streaming feature extractor.

    At present, we assume it returns a fbank feature extractor with
    fixed options. In the future, we will support passing in the options
    from outside.

    Returns:
      Return a CPU streaming feature extractor.
    """
    opts = FbankOptions()
    opts.device = "cpu"
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80
    return OnlineFbank(opts)


def greedy_search(
    model: OnnxModel,
    encoder_out: torch.Tensor,
    context_size: int,
    decoder_out: Optional[torch.Tensor] = None,
    hyp: Optional[List[int]] = None,
) -> Tuple[List[int], torch.Tensor]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      model:
        The transducer model.
      encoder_out:
        A 3-D tensor of shape (1, T, joiner_dim)
      context_size:
        The context size of the decoder model.
      decoder_out:
        Optional. Decoder output of the previous chunk.
      hyp:
        Decoding results for previous chunks.
    Returns:
      Return the decoded results so far and the decoder output of the
      last emitted symbol, to be passed in for the next chunk.
    """

    blank_id = 0

    if decoder_out is None:
        assert hyp is None, hyp
        hyp = [blank_id] * context_size
        decoder_input = torch.tensor([hyp], dtype=torch.int64)
        decoder_out = model.run_decoder(decoder_input)
    else:
        assert hyp is not None, hyp

    encoder_out = encoder_out.squeeze(0)
    T = encoder_out.size(0)
    for t in range(T):
        cur_encoder_out = encoder_out[t : t + 1]
        joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
        y = joiner_out.argmax(dim=0).item()
        if y != blank_id:
            hyp.append(y)
            decoder_input = hyp[-context_size:]
            decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
            decoder_out = model.run_decoder(decoder_input)

    return hyp, decoder_out


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    model = OnnxModel(
        encoder_model_filename=args.encoder_model_filename,
        decoder_model_filename=args.decoder_model_filename,
        joiner_model_filename=args.joiner_model_filename,
    )

    sample_rate = 16000

    logging.info("Constructing Fbank computer")
    online_fbank = create_streaming_feature_extractor()

    logging.info(f"Reading sound file: {args.sound_file}")
    waves = read_sound_files(
        filenames=[args.sound_file],
        expected_sample_rate=sample_rate,
    )[0]

    tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
    wave_samples = torch.cat([waves, tail_padding])

    num_processed_frames = 0
    segment = model.segment
    offset = model.offset

    context_size = model.context_size
    hyp = None
    decoder_out = None

    chunk = int(1 * sample_rate)  # 1 second
    start = 0
    while start < wave_samples.numel():
        end = min(start + chunk, wave_samples.numel())
        samples = wave_samples[start:end]
        start += chunk

        online_fbank.accept_waveform(
            sampling_rate=sample_rate,
            waveform=samples,
        )

        while online_fbank.num_frames_ready - num_processed_frames >= segment:
            frames = []
            for i in range(segment):
                frames.append(online_fbank.get_frame(num_processed_frames + i))
            num_processed_frames += offset
            frames = torch.cat(frames, dim=0)
            frames = frames.unsqueeze(0)
            encoder_out = model.run_encoder(frames)
            hyp, decoder_out = greedy_search(
                model,
                encoder_out,
                context_size,
                decoder_out,
                hyp,
            )

    token_table = k2.SymbolTable.from_file(args.tokens)

    text = ""
    for i in hyp[context_size:]:
        text += token_table[i]
    text = text.replace("▁", " ").strip()

    logging.info(args.sound_file)
    logging.info(text)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py
@ -1,421 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model to ONNX

./zipformer/export-onnx.py \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --causal False

It will generate the following 3 files inside $repo/exp:

  - encoder-epoch-99-avg-1.onnx
  - decoder-epoch-99-avg-1.onnx
  - joiner-epoch-99-avg-1.onnx

3. Run this file

./zipformer/onnx_pretrained.py \
  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
"""

import argparse
import logging
import math
from typing import List, Tuple

import k2
import kaldifeat
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder onnx model. ",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder onnx model. ",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner onnx model. ",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_model_filename: str,
|
||||
decoder_model_filename: str,
|
||||
joiner_model_filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 4
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_encoder(encoder_model_filename)
|
||||
self.init_decoder(decoder_model_filename)
|
||||
self.init_joiner(joiner_model_filename)
|
||||
|
||||
def init_encoder(self, encoder_model_filename: str):
|
||||
self.encoder = ort.InferenceSession(
|
||||
encoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def init_decoder(self, decoder_model_filename: str):
|
||||
self.decoder = ort.InferenceSession(
|
||||
decoder_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
|
||||
self.context_size = int(decoder_meta["context_size"])
|
||||
self.vocab_size = int(decoder_meta["vocab_size"])
|
||||
|
||||
logging.info(f"context_size: {self.context_size}")
|
||||
logging.info(f"vocab_size: {self.vocab_size}")
|
||||
|
||||
def init_joiner(self, joiner_model_filename: str):
|
||||
self.joiner = ort.InferenceSession(
|
||||
joiner_model_filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
|
||||
self.joiner_dim = int(joiner_meta["joiner_dim"])
|
||||
|
||||
logging.info(f"joiner_dim: {self.joiner_dim}")
|
||||
|
||||
def run_encoder(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- encoder_out, its shape is (N, T', joiner_dim)
|
||||
- encoder_out_lens, its shape is (N,)
|
||||
"""
|
||||
out = self.encoder.run(
|
||||
[
|
||||
self.encoder.get_outputs()[0].name,
|
||||
self.encoder.get_outputs()[1].name,
|
||||
],
|
||||
{
|
||||
self.encoder.get_inputs()[0].name: x.numpy(),
|
||||
self.encoder.get_inputs()[1].name: x_lens.numpy(),
|
||||
},
|
||||
)
|
||||
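# InferenceSession.run(output_names, input_feed) returns a list of numpy
# arrays, one per requested output, in the order given above.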
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
|
||||
|
||||
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
decoder_input:
|
||||
A 2-D tensor of shape (N, context_size)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
out = self.decoder.run(
|
||||
[self.decoder.get_outputs()[0].name],
|
||||
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
def run_joiner(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
decoder_out:
|
||||
A 2-D tensor of shape (N, joiner_dim)
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, vocab_size)
|
||||
"""
|
||||
out = self.joiner.run(
|
||||
[self.joiner.get_outputs()[0].name],
|
||||
{
|
||||
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
|
||||
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(out)
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: OnnxModel,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, joiner_dim)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
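# batch_sizes[t] is the number of utterances still active at frame t;
# e.g. lengths [5, 3, 2] give batch_sizes [3, 3, 2, 1, 1].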
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
# current_encoder_out's shape: (batch_size, joiner_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
logits = model.run_joiner(current_encoder_out, decoder_out)
|
||||
|
||||
# logits' shape: (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.run_decoder(decoder_input)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
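# A minimal single-utterance sketch of the same decoding rule, for clarity;
# it assumes the `model`, `encoder_out`, and `blank_id` defined above:
#
#     hyp = [blank_id] * model.context_size
#     for t in range(encoder_out.size(1)):
#         decoder_input = torch.tensor(
#             [hyp[-model.context_size :]], dtype=torch.int64
#         )
#         decoder_out = model.run_decoder(decoder_input)  # (1, joiner_dim)
#         logits = model.run_joiner(encoder_out[:1, t], decoder_out)
#         v = logits.argmax(dim=1).item()
#         if v != blank_id:  # emit at most one non-blank symbol per frame
#             hyp.append(v)
#     result = hyp[model.context_size :]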
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
model = OnnxModel(
|
||||
encoder_model_filename=args.encoder_model_filename,
|
||||
decoder_model_filename=args.decoder_model_filename,
|
||||
joiner_model_filename=args.joiner_model_filename,
|
||||
)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
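# 80-dim log-mel fbank features, matching the features used in training.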
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
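# log(1e-10) is a very small log-mel value, so padded frames resemble silence.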
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
s = "\n"
|
||||
|
||||
token_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||
text = ""
|
||||
for i in token_ids:
|
||||
text += token_table[i]
|
||||
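# e.g. pieces "▁HE", "LLO" concatenate to "▁HELLO", which the
# replace/strip below turns into the word "HELLO".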
return text.replace("▁", " ").strip()
|
||||
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = token_ids_to_words(hyp)
|
||||
s += f"{filename}:\n{words}\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1
egs/gigaspeech/ASR/zipformer/onnx_pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/onnx_pretrained.py
|
||||
@ -1,213 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Please follow ./export-onnx-ctc.py to get the onnx model.
|
||||
|
||||
2. Run this file
|
||||
|
||||
./zipformer/onnx_pretrained_ctc.py \
|
||||
--nn-model /path/to/model.onnx \
|
||||
--tokens /path/to/data/lang_bpe_500/tokens.txt \
|
||||
1089-134686-0001.wav \
|
||||
1221-135766-0001.wav \
|
||||
1221-135766-0002.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
nn_model: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_model(nn_model)
|
||||
|
||||
def init_model(self, nn_model: str):
|
||||
self.model = ort.InferenceSession(
|
||||
nn_model,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
meta = self.model.get_modelmeta().custom_metadata_map
|
||||
print(meta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D float tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D int64 tensor of shape (N,)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A float tensor containing log_probs of shape (N, T, C)
|
||||
- An int64 tensor containing log_probs_len of shape (N,)
|
||||
"""
|
||||
out = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
self.model.get_outputs()[1].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
self.model.get_inputs()[1].name: x_lens.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
model = OnnxModel(
|
||||
nn_model=args.nn_model,
|
||||
)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
log_probs, log_probs_len = model(features, feature_lengths)
|
||||
|
||||
token_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||
text = ""
|
||||
for i in token_ids:
|
||||
text += token_table[i]
|
||||
return text.replace("▁", " ").strip()
|
||||
|
||||
blank_id = 0
|
||||
s = "\n"
|
||||
for i in range(log_probs.size(0)):
|
||||
# greedy search
|
||||
indexes = log_probs[i, : log_probs_len[i]].argmax(dim=-1)
|
||||
token_ids = torch.unique_consecutive(indexes)
|
||||
|
||||
token_ids = token_ids[token_ids != blank_id]
|
||||
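# e.g. per-frame argmax [0, 0, 5, 5, 0, 7, 7] collapses to [0, 5, 0, 7]
# and, with blanks removed, yields the token sequence [5, 7].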
words = token_ids_to_words(token_ids.tolist())
|
||||
s += f"{args.sound_files[i]}:\n{words}\n\n"
|
||||
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1
egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py
|
||||
@ -1,277 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Please follow ./export-onnx-ctc.py to get the onnx model.
|
||||
|
||||
2. Run this file
|
||||
|
||||
./zipformer/onnx_pretrained_ctc_H.py \
|
||||
--nn-model /path/to/model.onnx \
|
||||
--tokens /path/to/data/lang_bpe_500/tokens.txt \
|
||||
--H /path/to/H.fst \
|
||||
1089-134686-0001.wav \
|
||||
1221-135766-0001.wav \
|
||||
1221-135766-0002.wav
|
||||
|
||||
You can find exported ONNX models at
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
|
||||
import kaldifst
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--H",
|
||||
type=str,
|
||||
help="""Path to H.fst.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
nn_model: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_model(nn_model)
|
||||
|
||||
def init_model(self, nn_model: str):
|
||||
self.model = ort.InferenceSession(
|
||||
nn_model,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
meta = self.model.get_modelmeta().custom_metadata_map
|
||||
print(meta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D float tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D int64 tensor of shape (N,)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A float tensor containing log_probs of shape (N, T, C)
|
||||
- An int64 tensor containing log_probs_len of shape (N,)
|
||||
"""
|
||||
out = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
self.model.get_outputs()[1].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
self.model.get_inputs()[1].name: x_lens.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
def decode(
|
||||
filename: str,
|
||||
log_probs: torch.Tensor,
|
||||
H: kaldifst.StdVectorFst,
|
||||
id2token: Dict[int, str],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Args:
|
||||
filename:
|
||||
Path to the filename for decoding. Used for debugging.
|
||||
log_probs:
|
||||
A 2-D float32 tensor of shape (num_frames, vocab_size). It
|
||||
contains output from log_softmax.
|
||||
H:
|
||||
The H graph.
|
||||
id2token:
|
||||
A map mapping token ID to word string.
|
||||
Returns:
|
||||
Return a list of decoded words.
|
||||
"""
|
||||
logging.info(f"{filename}, {log_probs.shape}")
|
||||
decodable = DecodableCtc(log_probs.cpu())
|
||||
|
||||
decoder_opts = FasterDecoderOptions(max_active=3000)
|
||||
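# max_active caps the number of active decoder states kept per frame
# (histogram pruning), trading search accuracy for speed.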
decoder = FasterDecoder(H, decoder_opts)
|
||||
decoder.decode(decodable)
|
||||
|
||||
if not decoder.reached_final():
|
||||
logging.info(f"failed to decode {filename}")
|
||||
return [""]
|
||||
|
||||
ok, best_path = decoder.get_best_path()
|
||||
|
||||
(
|
||||
ok,
|
||||
isymbols_out,
|
||||
osymbols_out,
|
||||
total_weight,
|
||||
) = kaldifst.get_linear_symbol_sequence(best_path)
|
||||
if not ok:
|
||||
logging.info(f"failed to get linear symbol sequence for {filename}")
|
||||
return [""]
|
||||
|
||||
# Token IDs are shifted by 1 during graph construction: output symbol i
# corresponds to token i - 1, and symbol 1 (the shifted blank) is skipped.
|
||||
hyps = [id2token[i - 1] for i in osymbols_out if i != 1]
|
||||
hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁
|
||||
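# e.g. pieces ["▁HE", "LLO", "▁WORLD"] join to "▁HELLO▁WORLD" and split
# into ["", "HELLO", "WORLD"]; the caller joins these with spaces.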
|
||||
return hyps
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
model = OnnxModel(
|
||||
nn_model=args.nn_model,
|
||||
)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
logging.info(f"Loading H from {args.H}")
|
||||
H = kaldifst.StdVectorFst.read(args.H)
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
log_probs, log_probs_len = model(features, feature_lengths)
|
||||
|
||||
token_table = k2.SymbolTable.from_file(args.tokens)
|
||||
|
||||
hyps = []
|
||||
for i in range(log_probs.shape[0]):
|
||||
hyp = decode(
|
||||
filename=args.sound_files[i],
|
||||
log_probs=log_probs[i, : log_probs_len[i]],
|
||||
H=H,
|
||||
id2token=token_table,
|
||||
)
|
||||
hyps.append(hyp)
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1
egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py
|
||||
@ -1,275 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Please follow ./export-onnx-ctc.py to get the onnx model.
|
||||
|
||||
2. Run this file
|
||||
|
||||
./zipformer/onnx_pretrained_ctc_HL.py \
|
||||
--nn-model /path/to/model.onnx \
|
||||
--words /path/to/data/lang_bpe_500/words.txt \
|
||||
--HL /path/to/HL.fst \
|
||||
1089-134686-0001.wav \
|
||||
1221-135766-0001.wav \
|
||||
1221-135766-0002.wav
|
||||
|
||||
You can find exported ONNX models at
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
|
||||
import kaldifst
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words",
|
||||
type=str,
|
||||
help="""Path to words.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HL",
|
||||
type=str,
|
||||
help="""Path to HL.fst.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
nn_model: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_model(nn_model)
|
||||
|
||||
def init_model(self, nn_model: str):
|
||||
self.model = ort.InferenceSession(
|
||||
nn_model,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
meta = self.model.get_modelmeta().custom_metadata_map
|
||||
print(meta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D float tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D int64 tensor of shape (N,)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A float tensor containing log_probs of shape (N, T, C)
|
||||
- An int64 tensor containing log_probs_len of shape (N,)
|
||||
"""
|
||||
out = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
self.model.get_outputs()[1].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
self.model.get_inputs()[1].name: x_lens.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
def decode(
|
||||
filename: str,
|
||||
log_probs: torch.Tensor,
|
||||
HL: kaldifst.StdVectorFst,
|
||||
id2word: Dict[int, str],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Args:
|
||||
filename:
|
||||
Path to the filename for decoding. Used for debugging.
|
||||
log_probs:
|
||||
A 2-D float32 tensor of shape (num_frames, vocab_size). It
|
||||
contains output from log_softmax.
|
||||
HL:
|
||||
The HL graph.
|
||||
id2word:
|
||||
A map mapping word ID to word string.
|
||||
Returns:
|
||||
Return a list of decoded words.
|
||||
"""
|
||||
logging.info(f"{filename}, {log_probs.shape}")
|
||||
decodable = DecodableCtc(log_probs.cpu())
|
||||
|
||||
decoder_opts = FasterDecoderOptions(max_active=3000)
|
||||
decoder = FasterDecoder(HL, decoder_opts)
|
||||
decoder.decode(decodable)
|
||||
|
||||
if not decoder.reached_final():
|
||||
logging.info(f"failed to decode {filename}")
|
||||
return [""]
|
||||
|
||||
ok, best_path = decoder.get_best_path()
|
||||
|
||||
(
|
||||
ok,
|
||||
isymbols_out,
|
||||
osymbols_out,
|
||||
total_weight,
|
||||
) = kaldifst.get_linear_symbol_sequence(best_path)
|
||||
if not ok:
|
||||
logging.info(f"failed to get linear symbol sequence for {filename}")
|
||||
return [""]
|
||||
|
||||
# The output symbols are already word IDs; map them directly to words.
|
||||
hyps = [id2word[i] for i in osymbols_out]
|
||||
|
||||
return hyps
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
model = OnnxModel(
|
||||
nn_model=args.nn_model,
|
||||
)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
logging.info(f"Loading HL from {args.HL}")
|
||||
HL = kaldifst.StdVectorFst.read(args.HL)
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
log_probs, log_probs_len = model(features, feature_lengths)
|
||||
|
||||
word_table = k2.SymbolTable.from_file(args.words)
|
||||
|
||||
hyps = []
|
||||
for i in range(log_probs.shape[0]):
|
||||
hyp = decode(
|
||||
filename=args.sound_files[i],
|
||||
log_probs=log_probs[i, : log_probs_len[i]],
|
||||
HL=HL,
|
||||
id2word=word_table,
|
||||
)
|
||||
hyps.append(hyp)
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1
egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
|
||||
@ -1,275 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
|
||||
We use the pre-trained model from
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
|
||||
as an example to show how to use this file.
|
||||
|
||||
1. Please follow ./export-onnx-ctc.py to get the onnx model.
|
||||
|
||||
2. Run this file
|
||||
|
||||
./zipformer/onnx_pretrained_ctc_HLG.py \
|
||||
--nn-model /path/to/model.onnx \
|
||||
--words /path/to/data/lang_bpe_500/words.txt \
|
||||
--HLG /path/to/HLG.fst \
|
||||
1089-134686-0001.wav \
|
||||
1221-135766-0001.wav \
|
||||
1221-135766-0002.wav
|
||||
|
||||
You can find exported ONNX models at
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
|
||||
import kaldifst
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the onnx model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words",
|
||||
type=str,
|
||||
help="""Path to words.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG",
|
||||
type=str,
|
||||
help="""Path to HLG.fst.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
nn_model: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.init_model(nn_model)
|
||||
|
||||
def init_model(self, nn_model: str):
|
||||
self.model = ort.InferenceSession(
|
||||
nn_model,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
meta = self.model.get_modelmeta().custom_metadata_map
|
||||
print(meta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D float tensor of shape (N, T, C)
|
||||
x_lens:
|
||||
A 1-D int64 tensor of shape (N,)
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A float tensor containing log_probs of shape (N, T, C)
|
||||
- An int64 tensor containing log_probs_len of shape (N,)
|
||||
"""
|
||||
out = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
self.model.get_outputs()[1].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
self.model.get_inputs()[1].name: x_lens.numpy(),
|
||||
},
|
||||
)
|
||||
return torch.from_numpy(out[0]), torch.from_numpy(out[1])
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
def decode(
|
||||
filename: str,
|
||||
log_probs: torch.Tensor,
|
||||
HLG: kaldifst.StdVectorFst,
|
||||
id2word: Dict[int, str],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Args:
|
||||
filename:
|
||||
Path to the filename for decoding. Used for debugging.
|
||||
log_probs:
|
||||
A 2-D float32 tensor of shape (num_frames, vocab_size). It
|
||||
contains output from log_softmax.
|
||||
HLG:
|
||||
The HLG graph.
|
||||
id2word:
|
||||
A map mapping word ID to word string.
|
||||
Returns:
|
||||
Return a list of decoded words.
|
||||
"""
|
||||
logging.info(f"{filename}, {log_probs.shape}")
|
||||
decodable = DecodableCtc(log_probs.cpu())
|
||||
|
||||
decoder_opts = FasterDecoderOptions(max_active=3000)
|
||||
decoder = FasterDecoder(HLG, decoder_opts)
|
||||
decoder.decode(decodable)
|
||||
|
||||
if not decoder.reached_final():
|
||||
logging.info(f"failed to decode {filename}")
|
||||
return [""]
|
||||
|
||||
ok, best_path = decoder.get_best_path()
|
||||
|
||||
(
|
||||
ok,
|
||||
isymbols_out,
|
||||
osymbols_out,
|
||||
total_weight,
|
||||
) = kaldifst.get_linear_symbol_sequence(best_path)
|
||||
if not ok:
|
||||
logging.info(f"failed to get linear symbol sequence for {filename}")
|
||||
return [""]
|
||||
|
||||
# The output symbols are already word IDs; map them directly to words.
|
||||
hyps = [id2word[i] for i in osymbols_out]
|
||||
|
||||
return hyps
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
model = OnnxModel(
|
||||
nn_model=args.nn_model,
|
||||
)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
logging.info(f"Loading HLG from {args.HLG}")
|
||||
HLG = kaldifst.StdVectorFst.read(args.HLG)
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
log_probs, log_probs_len = model(features, feature_lengths)
|
||||
|
||||
word_table = k2.SymbolTable.from_file(args.words)
|
||||
|
||||
hyps = []
|
||||
for i in range(log_probs.shape[0]):
|
||||
hyp = decode(
|
||||
filename=args.sound_files[i],
|
||||
log_probs=log_probs[i, : log_probs_len[i]],
|
||||
HLG=HLG,
|
||||
id2word=word_table,
|
||||
)
|
||||
hyps.append(hyp)
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1
egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
|
||||
@ -1,381 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
Note: This is an example for the LibriSpeech dataset; if you are using a
different dataset, you should change the argument values accordingly.
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--causal 1 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
Usage of this script:
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
(1) greedy search
|
||||
./zipformer/pretrained.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) modified beam search
|
||||
./zipformer/pretrained.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||
--method modified_beam_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) fast beam search
|
||||
./zipformer/pretrained.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||
--method fast_beam_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
- For streaming model:
|
||||
|
||||
(1) greedy search
|
||||
./zipformer/pretrained.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) modified beam search
|
||||
./zipformer/pretrained.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||
--method modified_beam_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) fast beam search
|
||||
./zipformer/pretrained.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||
--method fast_beam_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
|
||||
You can also use `./zipformer/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from export import num_tokens
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.unk_id = token_table["<unk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1
|
||||
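# +1 for the blank symbol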
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
"," not in params.chunk_size
|
||||
), "chunk_size should be one value in decoding."
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
# model forward
|
||||
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
|
||||
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
logging.info(msg)
|
||||
|
||||
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||
text = ""
|
||||
for i in token_ids:
|
||||
text += token_table[i]
|
||||
return text.replace("▁", " ").strip()
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
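# trivial_graph builds an acceptor that allows any token sequence, so the
# beam search here is effectively unconstrained by an external graph.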
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append(token_ids_to_words(hyp))
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append(token_ids_to_words(hyp))
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append(token_ids_to_words(hyp))
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
s += f"{filename}:\n{hyp}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
1
egs/gigaspeech/ASR/zipformer/pretrained.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/zipformer/pretrained.py
|
||||
@ -1,455 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--causal 1 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) ctc-decoding
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--method ctc-decoding \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) 1best
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--method 1best \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) nbest-rescoring
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method nbest-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
|
||||
(4) whole-lattice-rescoring
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method whole-lattice-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from ctc_decode import get_decoding_params
|
||||
from export import num_tokens
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
type=str,
|
||||
help="""Path to words.txt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG",
|
||||
type=str,
|
||||
help="""Path to HLG.pt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to tokens.txt.
|
||||
Used only when method is ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Possible values are:
|
||||
(0) ctc-decoding - Use CTC decoding. It uses a token table,
|
||||
i.e., lang_dir/tokens.txt, to convert
|
||||
word pieces to words. It needs neither a lexicon
|
||||
nor an n-gram LM.
|
||||
(1) 1best - Use the best path as decoding output. Only
|
||||
the transformer encoder output is used for decoding.
|
||||
We call it HLG decoding.
|
||||
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an LM, the path with
|
||||
the highest score is the decoding result.
|
||||
We call it HLG decoding + nbest n-gram LM rescoring.
|
||||
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + whole-lattice n-gram LM rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--G",
|
||||
type=str,
|
||||
help="""An LM for rescoring.
|
||||
Used only when method is
|
||||
whole-lattice-rescoring or nbest-rescoring.
|
||||
It's usually a 4-gram LM.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""
|
||||
Used only when method is nbest-rescoring.
|
||||
It specifies the size of n-best list.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=1.3,
|
||||
help="""
|
||||
Used only when method is whole-lattice-rescoring and nbest-rescoring.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="""
|
||||
Used only when method is nbest-rescoring.
|
||||
It specifies the scale for lattice.scores when
|
||||
extracting n-best lists. A smaller value results in
|
||||
more unique number of paths with the risk of missing
|
||||
the best path.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
        )
        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans

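
# A minimal usage sketch for read_sound_files (hypothetical file names; any
# format supported by torchaudio.load() works):
#
#   waves = read_sound_files(["foo.wav", "bar.wav"], expected_sample_rate=16000)
#   assert all(w.ndim == 1 for w in waves)
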
@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    params.update(vars(args))

    token_table = k2.SymbolTable.from_file(params.tokens)
    params.vocab_size = num_tokens(token_table) + 1  # +1 for blank
    params.blank_id = token_table["<blk>"]
    assert params.blank_id == 0
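    # k2.ctc_topo() (used below for ctc-decoding) treats token ID 0 as the
    # blank, hence the assertion above.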

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)
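    # Dither is disabled so that decoding is deterministic; snip_edges=False
    # keeps the frame count consistent with the training-time features.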

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)
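    # Padding with log(1e-10) corresponds to (near) zero energy in log-mel
    # space, so the padded frames look like silence to the model.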

    encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

    batch_size = ctc_output.shape[0]
    supervision_segments = torch.tensor(
        [
            [i, 0, feature_lengths[i].item() // params.subsampling_factor]
            for i in range(batch_size)
        ],
        dtype=torch.int32,
    )
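    # Each row is (utterance_index, start_frame, num_frames); the frame
    # counts are measured at the subsampled frame rate of the encoder output.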

    if params.method == "ctc-decoding":
        logging.info("Use CTC decoding")
        max_token_id = params.vocab_size - 1

        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
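        # H is the standard (unmodified) CTC topology: it maps sequences of
        # repeated tokens and blanks to a single output token sequence.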

        lattice = get_lattice(
            nnet_output=ctc_output,
            decoding_graph=H,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        token_ids = get_texts(best_path)
        hyps = [[token_table[i] for i in ids] for ids in token_ids]
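        # hyps contains word pieces at this point; they are joined into
        # words when the results are printed below.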
    elif params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]:
        logging.info(f"Loading HLG from {params.HLG}")
        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
        HLG = HLG.to(device)
        if not hasattr(HLG, "lm_scores"):
            # Needed for nbest-rescoring and whole-lattice-rescoring
            HLG.lm_scores = HLG.scores.clone()

        if params.method in [
            "nbest-rescoring",
            "whole-lattice-rescoring",
        ]:
            logging.info(f"Loading G from {params.G}")
            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
            G = G.to(device)
            if params.method == "whole-lattice-rescoring":
                # Add epsilon self-loops to G as we will compose
                # it with the whole lattice later
                G = k2.add_epsilon_self_loops(G)
                G = k2.arc_sort(G)

            # G.lm_scores is used to replace HLG.lm_scores during
            # LM rescoring.
            G.lm_scores = G.scores.clone()

        lattice = get_lattice(
            nnet_output=ctc_output,
            decoding_graph=HLG,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        if params.method == "1best":
            logging.info("Use HLG decoding")
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
        elif params.method == "nbest-rescoring":
            logging.info("Use HLG decoding + n-best n-gram LM rescoring")
            best_path_dict = rescore_with_n_best_list(
                lattice=lattice,
                G=G,
                num_paths=params.num_paths,
                lm_scale_list=[params.ngram_lm_scale],
                nbest_scale=params.nbest_scale,
            )
            best_path = next(iter(best_path_dict.values()))
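            # rescore_with_n_best_list returns a dict mapping each LM scale
            # to its best path; a single scale is passed, so the only value
            # is the result.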
        elif params.method == "whole-lattice-rescoring":
            logging.info("Use HLG decoding + whole-lattice n-gram LM rescoring")
            best_path_dict = rescore_with_whole_lattice(
                lattice=lattice,
                G_with_epsilon_loops=G,
                lm_scale_list=[params.ngram_lm_scale],
            )
            best_path = next(iter(best_path_dict.values()))

        hyps = get_texts(best_path)
        word_sym_table = k2.SymbolTable.from_file(params.words_file)
        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
    else:
        raise ValueError(f"Unsupported decoding method: {params.method}")

    s = "\n"
    if params.method == "ctc-decoding":
        for filename, hyp in zip(params.sound_files, hyps):
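            # "▁" is the SentencePiece word-boundary marker: joining the
            # word pieces and replacing it with a space recovers the words.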
            words = "".join(hyp)
            words = words.replace("▁", " ").strip()
            s += f"{filename}:\n{words}\n\n"
    elif params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]:
        for filename, hyp in zip(params.sound_files, hyps):
            words = " ".join(hyp)
            words = words.replace("▁", " ").strip()
            s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
egs/gigaspeech/ASR/zipformer/pretrained_ctc.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/pretrained_ctc.py