Add export script for the yesno recipe. (#1212)

This commit is contained in:
Fangjun Kuang 2023-08-11 23:57:00 +08:00 committed by GitHub
parent 74806b744b
commit d6b28a11a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 813 additions and 17 deletions

View File

@ -44,11 +44,6 @@ jobs:
with:
fetch-depth: 0
- name: Install graphviz
shell: bash
run: |
sudo apt-get -qq install graphviz
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
@ -70,6 +65,7 @@ jobs:
pip install --no-binary protobuf protobuf==3.20.*
pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl
pip install kaldifeat==1.25.0.dev20230726+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html
- name: Run yesno recipe
shell: bash
@ -78,9 +74,75 @@ jobs:
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
cd egs/yesno/ASR
./prepare.sh
python3 ./tdnn/train.py
python3 ./tdnn/decode.py
# TODO: Check that the WER is less than some value
- name: Test exporting to pretrained.pt
shell: bash
working-directory: ${{github.workspace}}
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
cd egs/yesno/ASR
python3 ./tdnn/export.py --epoch 14 --avg 2
python3 ./tdnn/pretrained.py \
--checkpoint ./tdnn/exp/pretrained.pt \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
- name: Test exporting to torchscript
shell: bash
working-directory: ${{github.workspace}}
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
cd egs/yesno/ASR
python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
python3 ./tdnn/jit_pretrained.py \
--nn-model ./tdnn/exp/cpu_jit.pt \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
- name: Test exporting to onnx
shell: bash
working-directory: ${{github.workspace}}
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
cd egs/yesno/ASR
python3 ./tdnn/export_onnx.py --epoch 14 --avg 2
echo "Test float32 model"
python3 ./tdnn/onnx_pretrained.py \
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
echo "Test int8 model"
python3 ./tdnn/onnx_pretrained.py \
--nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
- name: Show generated files
shell: bash
working-directory: ${{github.workspace}}
run: |
cd egs/yesno/ASR
ls -lh tdnn/exp

View File

@ -65,7 +65,6 @@ def get_params() -> AttributeDict:
{
"exp_dir": Path("tdnn/exp/"),
"lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"),
"feature_dim": 23,
"search_beam": 20,
"output_beam": 8,

118
egs/yesno/ASR/tdnn/export.py Executable file
View File

@ -0,0 +1,118 @@
#!/usr/bin/env python3
"""
This file is for exporting trained models to a checkpoint
or to a torchscript model.
(1) Generate the checkpoint tdnn/exp/pretrained.pt
./tdnn/export.py \
--epoch 14 \
--avg 2
See ./tdnn/pretrained.py for how to use the generated file.
(2) Generate torchscript model tdnn/exp/cpu_jit.pt
./tdnn/export.py \
--epoch 14 \
--avg 2 \
--jit 1
See ./tdnn/jit_pretrained.py for how to use the generated file.
"""
import argparse
import logging
import torch
from model import Tdnn
from train import get_params
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=14,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=2,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
model = Tdnn(
num_features=params.feature_dim,
num_classes=max_token_id + 1, # +1 for the blank symbol
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

158
egs/yesno/ASR/tdnn/export_onnx.py Executable file
View File

@ -0,0 +1,158 @@
#!/usr/bin/env python3
"""
This file is for exporting trained models to onnx.
Usage:
./tdnn/export_onnx.py \
--epoch 14 \
--avg 2
The above command generates the following two files:
- ./exp/model-epoch-14-avg-2.onnx
- ./exp/model-epoch-14-avg-2.int8.onnx
See ./tdnn/onnx_pretrained.py for how to use them.
"""
import argparse
import logging
from typing import Dict
import onnx
import torch
from model import Tdnn
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import get_params
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=14,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=2,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
@torch.no_grad()
def main():
args = get_parser().parse_args()
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
model = Tdnn(
num_features=params.feature_dim,
num_classes=max_token_id + 1, # +1 for the blank symbol
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to("cpu")
model.eval()
N = 1
T = 100
C = params.feature_dim
x = torch.rand(N, T, C)
opset_version = 13
onnx_filename = f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.onnx"
torch.onnx.export(
model,
x,
onnx_filename,
verbose=False,
opset_version=opset_version,
input_names=["x"],
output_names=["log_prob"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"log_prob": {0: "N", 1: "T"},
},
)
logging.info(f"Saved to {onnx_filename}")
meta_data = {
"model_type": "tdnn_lstm",
"version": "1",
"model_author": "k2-fsa",
"comment": "non-streaming tdnn for the yesno recipe",
"vocab_size": max_token_id + 1,
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=onnx_filename, meta_data=meta_data)
logging.info("Generate int8 quantization models")
onnx_filename_int8 = (
f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.int8.onnx"
)
quantize_dynamic(
model_input=onnx_filename,
model_output=onnx_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
logging.info(f"Saved to {onnx_filename_int8}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
This file shows how to use a torchscript model for decoding.
Usage:
./tdnn/jit_pretrained.py \
--nn-model ./tdnn/exp/cpu_jit.pt \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
Note that to generate ./tdnn/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""
import argparse
import logging
from typing import List
import math
import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1
to obtain it
""",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
"num_classes": 4, # [<blk>, N, SIL, Y]
"sample_rate": 8000,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Loading torchscript model")
model = torch.jit.load(args.nn_model)
model.eval()
model.to(device)
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
# Note: We don't use key padding mask for attention during decoding
nnet_output = model(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,241 @@
#!/usr/bin/env python3
"""
This file shows how to use an ONNX model for decoding with onnxruntime.
Usage:
(1) Use a not quantized ONNX model, i.e., a float32 model
./tdnn/onnx_pretrained.py \
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
(2) Use a quantized ONNX model, i.e., an int8 model
./tdnn/onnx_pretrained.py \
--nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
Note that to generate ./tdnn/exp/model-epoch-14-avg-2.onnx,
and ./tdnn/exp/model-epoch-14-avg-2.onnx,
you can use ./export_onnx.py --epoch 14 --avg 2
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import onnxruntime as ort
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
class OnnxModel:
def __init__(self, nn_model: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.model = ort.InferenceSession(
nn_model,
sess_options=self.session_opts,
)
meta = self.model.get_modelmeta().custom_metadata_map
self.vocab_size = int(meta["vocab_size"])
def run(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor log_prob of shape (N, T, C)
"""
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
},
)
return torch.from_numpy(out[0])
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1
to obtain it
""",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
"sample_rate": 8000,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info(f"Loading onnx model {params.nn_model}")
model = OnnxModel(params.nn_model)
logging.info(f"Loading HLG from {args.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
# Note: We don't use key padding mask for attention during decoding
nnet_output = model.run(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -15,6 +15,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file shows how to use a checkpoint for decoding.
Usage:
./tdnn/pretrained.py \
--checkpoint ./tdnn/exp/pretrained.pt \
--HLG ./data/lang_phone/HLG.pt \
--words-file ./data/lang_phone/words.txt \
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
Note that to generate ./tdnn/exp/pretrained.pt,
you can use ./export.py
"""
import argparse
import logging
@ -43,7 +58,8 @@ def get_parser():
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
"icefall.checkpoint.save_checkpoint(). "
"You can use ./tdnn/export.py to obtain it.",
)
parser.add_argument(
@ -61,8 +77,7 @@ def get_parser():
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
"For example, wav and flac are supported. ",
)
return parser
@ -99,14 +114,19 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0])
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
@ -159,8 +179,7 @@ def main():
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
nnet_output = model(features)
nnet_output = model(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(