Support exporting RNNLM to ONNX. (#1014)
* Support exporting RNNLM to ONNX.
* Add int8 models.
* Fix style issues.
* Fix EOS padding.
* Support exporting for streaming ASR.

This commit is contained in:
parent 45c13e90e4
commit 2767b9ff11
icefall/rnn_lm/.gitignore (vendored, new file, 1 line)
@@ -0,0 +1 @@
icefall-librispeech-rnn-lm
icefall/rnn_lm/check-onnx-streaming.py (executable, new file, 132 lines)
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation

"""
Usage:

./check-onnx-streaming.py \
  --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
  --onnx ./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx

Note: You can download pre-trained models from
https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
"""

import argparse
import logging
from typing import Tuple

import onnxruntime as ort
import torch


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--jit",
        required=True,
        type=str,
        help="Path to the torchscript model",
    )

    parser.add_argument(
        "--onnx",
        required=True,
        type=str,
        help="Path to the onnx model",
    )

    return parser


class OnnxModel:
    def __init__(self, filename: str):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.model = ort.InferenceSession(
            filename,
            sess_options=session_opts,
        )

        meta_data = self.model.get_modelmeta().custom_metadata_map
        self.sos_id = int(meta_data["sos_id"])
        self.eos_id = int(meta_data["eos_id"])
        self.vocab_size = int(meta_data["vocab_size"])
        self.num_layers = int(meta_data["num_layers"])
        self.hidden_size = int(meta_data["hidden_size"])
        print(meta_data)

    def __call__(
        self, x: torch.Tensor, y: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
                self.model.get_outputs()[2].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: y.numpy(),
                self.model.get_inputs()[2].name: h0.numpy(),
                self.model.get_inputs()[3].name: c0.numpy(),
            },
        )
        return (
            torch.from_numpy(out[0]),
            torch.from_numpy(out[1]),
            torch.from_numpy(out[2]),
        )


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    logging.info(vars(args))

    torch_model = torch.jit.load(args.jit).cpu()
    onnx_model = OnnxModel(args.onnx)
    N = torch.arange(1, 5).tolist()

    num_layers = onnx_model.num_layers
    hidden_size = onnx_model.hidden_size

    for n in N:
        L = torch.randint(low=1, high=100, size=(1,)).item()
        x = torch.randint(
            low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
        )
        y = torch.randint(
            low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
        )
        h0 = torch.rand(num_layers, n, hidden_size)
        c0 = torch.rand(num_layers, n, hidden_size)

        torch_nll, torch_h0, torch_c0 = torch_model.streaming_forward(x, y, h0, c0)
        onnx_nll, onnx_h0, onnx_c0 = onnx_model(x, y, h0, c0)

        for torch_v, onnx_v in zip(
            (torch_nll, torch_h0, torch_c0), (onnx_nll, onnx_h0, onnx_c0)
        ):
            assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
                torch_v.shape,
                onnx_v.shape,
                (torch_v - onnx_v).abs().max(),
            )
            print(n, L, torch_v.sum(), onnx_v.sum())


if __name__ == "__main__":
    torch.manual_seed(20230423)
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
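Editor's note: because h0/c0 are explicit inputs and next_h0/next_c0 are outputs, the with-state model can also score a sequence chunk by chunk, carrying the LSTM state across chunks. A minimal sketch, reusing the OnnxModel class above; the model path and the chunk size of 4 are assumptions, not part of this commit:

import torch

# assumed path; adjust to whatever export-onnx.py produced
model = OnnxModel("./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx")

tokens = torch.randint(1, model.vocab_size, size=(1, 13), dtype=torch.int64)
x_all = tokens[:, :-1]  # inputs
y_all = tokens[:, 1:]  # targets: the next token at each position
h = torch.zeros(model.num_layers, 1, model.hidden_size)
c = torch.zeros(model.num_layers, 1, model.hidden_size)

total_nll = torch.zeros(1)
for i in range(0, x_all.size(1), 4):
    # carry the state across chunks; the sum equals one full-sequence pass
    nll, h, c = model(x_all[:, i : i + 4], y_all[:, i : i + 4], h, c)
    total_nll += nll
print(total_nll)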
icefall/rnn_lm/check-onnx.py (executable, new file, 119 lines)
@@ -0,0 +1,119 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation

"""
Usage:

./check-onnx.py \
  --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
  --onnx ./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx

Note: You can download pre-trained models from
https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
"""

import argparse
import logging

import onnxruntime as ort
import torch


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--jit",
        required=True,
        type=str,
        help="Path to the torchscript model",
    )

    parser.add_argument(
        "--onnx",
        required=True,
        type=str,
        help="Path to the onnx model",
    )

    return parser


class OnnxModel:
    def __init__(self, filename: str):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.model = ort.InferenceSession(
            filename,
            sess_options=session_opts,
        )

        meta_data = self.model.get_modelmeta().custom_metadata_map
        self.sos_id = int(meta_data["sos_id"])
        self.eos_id = int(meta_data["eos_id"])
        self.vocab_size = int(meta_data["vocab_size"])
        print(meta_data)

    def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(out[0])


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    logging.info(vars(args))

    torch_model = torch.jit.load(args.jit).cpu()
    onnx_model = OnnxModel(args.onnx)
    N = torch.arange(1, 5).tolist()

    for n in N:
        L = torch.randint(low=1, high=100, size=(1,)).item()
        x = torch.randint(
            low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
        )
        x_lens = torch.full((n,), fill_value=L, dtype=torch.int64)
        if n > 1:
            # make the first utterance shorter so the batch contains padding
            x_lens[0] = L // 2 + 1

        # prepend SOS to form the LM input
        sos = torch.full((1,), fill_value=onnx_model.sos_id).expand(n, 1)
        sos_x = torch.cat([sos, x], dim=1)

        # append one pad column, then write EOS right after the last
        # valid token of each utterance to form the LM target
        pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1)
        x_eos = torch.cat([x, pad_col], dim=1)

        row_index = torch.arange(0, n, dtype=x.dtype)
        x_eos[row_index, x_lens] = onnx_model.eos_id

        torch_nll = torch_model(sos_x, x_eos, x_lens + 1).sum(dim=-1)
        onnx_nll = onnx_model(x, x_lens)
        # Note: For int8 models, the differences may be quite large,
        # e.g., within 0.9
        assert torch.allclose(torch_nll, onnx_nll), (
            torch_nll,
            onnx_nll,
        )
        print(n, L, torch_nll, onnx_nll)


if __name__ == "__main__":
    torch.manual_seed(20230420)
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
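Editor's note: the sos_x/x_eos construction above is the "Fix EOS padding" item from the commit message, and it is the subtlest part of this script. A tiny standalone walkthrough, with made-up ids (sos_id=1, eos_id=2) and a 2-utterance batch whose first utterance is shorter than the padded length:

import torch

sos_id, eos_id = 1, 2
x = torch.tensor([[5, 7, 0, 0], [5, 7, 9, 4]])  # first row: 2 valid tokens
x_lens = torch.tensor([2, 4])

sos = torch.full((1,), sos_id, dtype=x.dtype).expand(2, 1)
sos_x = torch.cat([sos, x], dim=1)
# sos_x: [[1, 5, 7, 0, 0],
#         [1, 5, 7, 9, 4]]

pad_col = torch.zeros((1,), dtype=x.dtype).expand(2, 1)
x_eos = torch.cat([x, pad_col], dim=1)
x_eos[torch.arange(2), x_lens] = eos_id
# x_eos: [[5, 7, 2, 0, 0],   <- EOS lands right after the last valid token
#         [5, 7, 9, 4, 2]]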
icefall/rnn_lm/export-onnx.py (executable, new file, 395 lines)
@@ -0,0 +1,395 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation

import argparse
import logging
from pathlib import Path
from typing import Dict

import onnx
import torch
from model import RnnLmModel
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import get_params

from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, str2bool


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


# A wrapper for the RnnLm model to simplify the C++ calling code
# when exporting the model to ONNX.
#
# TODO(fangjun): The current wrapper works only for non-streaming ASR
# since we don't expose the LM state and it is used to score
# a complete sentence at once.
class RnnLmModelWrapper(torch.nn.Module):
    def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int):
        super().__init__()
        self.model = model
        self.sos_id = sos_id
        self.eos_id = eos_id

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (N, L) with dtype torch.int64.
            It does not contain SOS or EOS. We will add SOS and EOS inside
            this function.
          x_lens:
            A 1-D tensor of shape (N,) with dtype torch.int64. It contains
            the number of valid tokens in ``x`` before padding.
        Returns:
          Return a 1-D tensor of shape (N,) containing the negative
          log-likelihood. Its dtype is torch.float32.
        """
        N = x.size(0)

        sos_tensor = torch.full((1,), fill_value=self.sos_id, dtype=x.dtype).expand(
            N, 1
        )
        sos_x = torch.cat([sos_tensor, x], dim=1)

        pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1)
        x_eos = torch.cat([x, pad_col], dim=1)

        row_index = torch.arange(0, N, dtype=x.dtype)
        x_eos[row_index, x_lens] = self.eos_id

        # use x_lens + 1 here since we prepended x with sos
        return (
            self.model(x=sos_x, y=x_eos, lengths=x_lens + 1)
            .to(torch.float32)
            .sum(dim=1)
        )


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=29,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=5,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--vocab-size",
        type=int,
        default=500,
        help="Vocabulary size of the model",
    )

    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=2048,
        help="Embedding dim of the model",
    )

    parser.add_argument(
        "--hidden-dim",
        type=int,
        default=2048,
        help="Hidden dim of the model",
    )

    parser.add_argument(
        "--num-layers",
        type=int,
        default=3,
        help="Number of RNN layers in the model",
    )

    parser.add_argument(
        "--tie-weights",
        type=str2bool,
        default=True,
        help="""True to share the weights between the input embedding layer and the
        last output linear layer
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="rnn_lm/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    return parser


def export_without_state(
    model: RnnLmModel,
    filename: str,
    params: AttributeDict,
    opset_version: int,
):
    model_wrapper = RnnLmModelWrapper(
        model,
        sos_id=params.sos_id,
        eos_id=params.eos_id,
    )

    N = 1
    L = 20
    x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
    x_lens = torch.full((N,), fill_value=L, dtype=torch.int64)

    # Note(fangjun): The following warnings can be ignored.
    # We can use ./check-onnx.py to validate the exported model with batch_size > 1
    """
    torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
    with a batch_size other than 1, with a variable length with LSTM can cause
    an error when running the ONNX model with a different batch size. Make sure
    to save the model with a batch size of 1, or define the initial states
    (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
    with a batch_size other than 1, " +
    """

    torch.onnx.export(
        model_wrapper,
        (x, x_lens),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["nll"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},
            "x_lens": {0: "N"},
            "nll": {0: "N"},
        },
    )

    meta_data = {
        "model_type": "rnnlm",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "rnnlm without state",
        "sos_id": str(params.sos_id),
        "eos_id": str(params.eos_id),
        "vocab_size": str(params.vocab_size),
        "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=filename, meta_data=meta_data)


def export_with_state(
    model: RnnLmModel,
    filename: str,
    params: AttributeDict,
    opset_version: int,
):
    N = 1
    L = 20
    num_layers = model.rnn.num_layers
    hidden_size = model.rnn.hidden_size

    x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
    y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
    h0 = torch.zeros(num_layers, N, hidden_size)
    c0 = torch.zeros(num_layers, N, hidden_size)

    # Note(fangjun): The following warnings can be ignored.
    # We can use ./check-onnx-streaming.py to validate the exported model
    # with batch_size > 1
    """
    torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
    with a batch_size other than 1, with a variable length with LSTM can cause
    an error when running the ONNX model with a different batch size. Make sure
    to save the model with a batch size of 1, or define the initial states
    (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
    with a batch_size other than 1, " +
    """

    torch.onnx.export(
        model,
        (x, y, h0, c0),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "y", "h0", "c0"],
        output_names=["nll", "next_h0", "next_c0"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},
            "y": {0: "N", 1: "L"},
            "h0": {1: "N"},
            "c0": {1: "N"},
            "nll": {0: "N"},
            "next_h0": {1: "N"},
            "next_c0": {1: "N"},
        },
    )

    meta_data = {
        "model_type": "rnnlm",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "rnnlm with state",
        "sos_id": str(params.sos_id),
        "eos_id": str(params.eos_id),
        "vocab_size": str(params.vocab_size),
        "num_layers": str(num_layers),
        "hidden_size": str(hidden_size),
        "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=filename, meta_data=meta_data)


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    logging.info(params)

    device = torch.device("cpu")
    logging.info(f"device: {device}")

    model = RnnLmModel(
        vocab_size=params.vocab_size,
        embedding_dim=params.embedding_dim,
        hidden_dim=params.hidden_dim,
        num_layers=params.num_layers,
        tie_weights=params.tie_weights,
    )

    model.to(device)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if i >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )

    model.to("cpu")
    model.eval()

    if params.iter > 0:
        suffix = f"iter-{params.iter}"
    else:
        suffix = f"epoch-{params.epoch}"

    suffix += f"-avg-{params.avg}"

    opset_version = 13

    logging.info("Exporting model without state")
    filename = params.exp_dir / f"no-state-{suffix}.onnx"
    export_without_state(
        model=model,
        filename=filename,
        params=params,
        opset_version=opset_version,
    )

    filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QInt8,
    )

    # now for the streaming export: temporarily swap in streaming_forward
    # as forward so that torch.onnx.export traces the stateful entry point
    saved_forward = model.__class__.forward
    model.__class__.forward = model.__class__.streaming_forward
    streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx"
    export_with_state(
        model=model,
        filename=streaming_filename,
        params=params,
        opset_version=opset_version,
    )
    model.__class__.forward = saved_forward

    streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=streaming_filename,
        model_output=streaming_filename_int8,
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
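Editor's note: to confirm that add_meta_data took effect without spinning up an InferenceSession, the model file can simply be reloaded with onnx; the path below is an assumption based on the default suffixes above:

import onnx

# assumed output of export-onnx.py with --epoch 99 --avg 1
m = onnx.load("./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx")
print({p.key: p.value for p in m.metadata_props})
# expected keys: model_type, version, model_author, comment, sos_id,
# eos_id, vocab_size, num_layers, hidden_size, url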
icefall/rnn_lm/export-onnx.sh (executable, new file, 26 lines)
@@ -0,0 +1,26 @@
#!/usr/bin/env bash

# We use the model from
# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
# as an example

export CUDA_VISIBLE_DEVICES=

if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
  GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
  pushd icefall-librispeech-rnn-lm/exp
  git lfs pull --include "pretrained.pt"
  ln -s pretrained.pt epoch-99.pt
  popd
fi

python3 ./export-onnx.py \
  --exp-dir ./icefall-librispeech-rnn-lm/exp \
  --epoch 99 \
  --avg 1 \
  --vocab-size 500 \
  --embedding-dim 2048 \
  --hidden-dim 2048 \
  --num-layers 3 \
  --tie-weights 1
icefall/rnn_lm/export.py (modified)
@@ -26,7 +26,7 @@ import torch
 from model import RnnLmModel
 
 from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
-from icefall.utils import AttributeDict, load_averaged_model, str2bool
+from icefall.utils import AttributeDict, str2bool
 
 
 def get_parser():
@@ -118,6 +118,7 @@ def get_parser():
     return parser
 
 
+@torch.no_grad()
 def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -180,6 +181,10 @@ def main():
 
     if params.jit:
         logging.info("Using torch.jit.script")
+
+        model.__class__.streaming_forward = torch.jit.export(
+            model.__class__.streaming_forward
+        )
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
         model.save(str(filename))
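Editor's note: torch.jit.script only compiles forward and the methods it reaches, so without the torch.jit.export decoration streaming_forward would be missing from cpu_jit.pt. With it, the method is callable on the loaded TorchScript module, which check-onnx-streaming.py relies on. A minimal sketch, assuming export.sh has produced cpu_jit.pt and using the dimensions from the script (vocab 500, 3 layers, hidden 2048):

import torch

m = torch.jit.load("./icefall-librispeech-rnn-lm/exp/cpu_jit.pt")
x = torch.randint(1, 500, (1, 5))
y = torch.randint(1, 500, (1, 5))
h0 = torch.zeros(3, 1, 2048)
c0 = torch.zeros(3, 1, 2048)
nll, h0, c0 = m.streaming_forward(x, y, h0, c0)  # exported method is available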
icefall/rnn_lm/export.sh (executable, new file, 27 lines)
@@ -0,0 +1,27 @@
#!/usr/bin/env bash

# We use the model from
# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
# as an example

export CUDA_VISIBLE_DEVICES=

if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
  GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
  pushd icefall-librispeech-rnn-lm/exp
  git lfs pull --include "pretrained.pt"
  ln -s pretrained.pt epoch-99.pt
  popd
fi

python3 ./export.py \
  --exp-dir ./icefall-librispeech-rnn-lm/exp \
  --epoch 99 \
  --avg 1 \
  --vocab-size 500 \
  --embedding-dim 2048 \
  --hidden-dim 2048 \
  --num-layers 3 \
  --tie-weights 1 \
  --jit 1
icefall/rnn_lm/model.py (modified)
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import Tuple
 
 import torch
 import torch.nn.functional as F
@@ -47,6 +48,11 @@ class RnnLmModel(torch.nn.Module):
         and https://arxiv.org/abs/1611.01462
         """
         super().__init__()
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.tie_weights = tie_weights
 
         self.input_embedding = torch.nn.Embedding(
             num_embeddings=vocab_size,
@@ -74,6 +80,46 @@ class RnnLmModel(torch.nn.Module):
 
         self.cache = {}
 
+    def streaming_forward(
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        h0: torch.Tensor,
+        c0: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 2-D tensor of shape (N, L). We won't prepend it with SOS.
+          y:
+            A 2-D tensor of shape (N, L). We won't append it with EOS.
+          h0:
+            A 3-D tensor of shape (num_layers, N, hidden_size).
+            (If proj_size > 0, then it is (num_layers, N, proj_size).)
+          c0:
+            A 3-D tensor of shape (num_layers, N, hidden_size).
+        Returns:
+          Return a tuple containing 3 tensors:
+            - negative log-likelihood (nll), a 1-D tensor of shape (N,)
+            - next_h0, a 3-D tensor with the same shape as h0
+            - next_c0, a 3-D tensor with the same shape as c0
+        """
+        assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
+        assert x.shape == y.shape, (x.shape, y.shape)
+
+        # embedding is of shape (N, L, embedding_dim)
+        embedding = self.input_embedding(x)
+        # Note: We use batch_first=True
+        rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0))
+        logits = self.output_linear(rnn_out)
+        nll_loss = F.cross_entropy(
+            logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
+        )
+
+        batch_size = x.size(0)
+        nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1)
+        return nll_loss, next_h0, next_c0
+
     def forward(
         self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
     ) -> torch.Tensor:
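Editor's note: the streaming contract of this method is that feeding chunks with carried (h0, c0) must give the same total nll as one full-sequence call, which is what enables the streaming-ASR use case. A hedged sketch of that property check; the small dimensions and tie_weights=False are arbitrary choices for speed, not values from this commit:

import torch
from model import RnnLmModel

model = RnnLmModel(
    vocab_size=100, embedding_dim=16, hidden_dim=16, num_layers=2, tie_weights=False
).eval()

x = torch.randint(1, 100, (1, 12))
y = torch.randint(1, 100, (1, 12))
h = torch.zeros(2, 1, 16)
c = torch.zeros(2, 1, 16)

with torch.no_grad():
    full_nll, _, _ = model.streaming_forward(x, y, h, c)

    chunked_nll = torch.zeros(1)
    for i in range(0, 12, 4):
        # carry the state so that chunking does not change the result
        nll, h, c = model.streaming_forward(x[:, i : i + 4], y[:, i : i + 4], h, c)
        chunked_nll += nll

assert torch.allclose(full_nll, chunked_nll, atol=1e-5)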