commit 522a45ce75
parent 45c7894111

    add using averaged model in export.py
@@ -356,7 +356,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
+    encoder_out, encoder_out_lens, _ = model.encoder(
         x=feature, x_lens=feature_lens
     )
 
@@ -34,25 +34,7 @@ Usage:
 It will generates 3 files: `encoder_jit_trace.pt`,
 `decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
 
-
-(3) Export to ONNX format
-
-./lstm_transducer_stateless/export.py \
-  --exp-dir ./lstm_transducer_stateless/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
-  --epoch 20 \
-  --avg 10 \
-  --onnx 1
-
-It will generate the following three files in the given `exp_dir`.
-Check `onnx_check.py` for how to use them.
-
-  - encoder.onnx
-  - decoder.onnx
-  - joiner.onnx
-
-
-(4) Export `model.state_dict()`
+(2) Export `model.state_dict()`
 
 ./lstm_transducer_stateless/export.py \
   --exp-dir ./lstm_transducer_stateless/exp \
@@ -97,7 +79,6 @@ import argparse
 import logging
 from pathlib import Path
 
-import onnx
 import sentencepiece as spm
 import torch
 import torch.nn as nn
@@ -106,6 +87,7 @@ from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
+    average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
@@ -145,6 +127,17 @@ def get_parser():
         "'--epoch' and '--iter'",
     )
 
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
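The help text above relies on a property of the running average kept during training: to average over the range (epoch-avg, epoch], only the two boundary checkpoints need to be read. A minimal sketch of that arithmetic, with illustrative names rather than the actual signature of average_checkpoints_with_averaged_model():

    import torch

    def average_over_range(
        avg_start: torch.Tensor,  # running average stored at epoch `epoch-avg`
        avg_end: torch.Tensor,  # running average stored at epoch `epoch`
        count_start: int,  # number of updates folded into avg_start
        count_end: int,  # number of updates folded into avg_end
    ) -> torch.Tensor:
        # avg_end * count_end is the running sum of parameters up to
        # `epoch`; subtracting avg_start * count_start leaves the sum
        # over (epoch-avg, epoch], renormalized by the update count.
        return (avg_end * count_end - avg_start * count_start) / (
            count_end - count_start
        )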
@@ -175,21 +168,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--onnx",
-        type=str2bool,
-        default=False,
-        help="""If True, --jit is ignored and it exports the model
-        to onnx format. Three files will be generated:
-
-            - encoder.onnx
-            - decoder.onnx
-            - joiner.onnx
-
-        Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
-        """,
-    )
-
     parser.add_argument(
         "--context-size",
         type=int,
@@ -220,7 +198,6 @@ def export_encoder_model_jit_trace(
     x = torch.zeros(1, 100, 80, dtype=torch.float32)
     x_lens = torch.tensor([100], dtype=torch.int64)
     states = encoder_model.get_init_states()
-    states = (states[0].unsqueeze(1), states[1].unsqueeze(1))
 
     traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
     traced_model.save(encoder_filename)
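A hedged sketch of loading and running the traced encoder saved above; the state shapes mirror what is passed to torch.jit.trace(). The sizes match the RNN test configuration updated later in this diff (d_model=512, rnn_hidden_size=1024, num_encoder_layers=12), but a real exported model's sizes depend on the training flags:

    import torch

    encoder = torch.jit.load("encoder_jit_trace.pt")
    x = torch.zeros(1, 100, 80, dtype=torch.float32)
    x_lens = torch.tensor([100], dtype=torch.int64)
    # Assumed shapes: (num_encoder_layers, N, d_model) and
    # (num_encoder_layers, N, rnn_hidden_size) with N = 1.
    states = (torch.zeros(12, 1, 512), torch.zeros(12, 1, 1024))
    encoder_out, encoder_out_lens, new_states = encoder(x, x_lens, states)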
@@ -276,187 +253,6 @@ def export_joiner_model_jit_trace(
     logging.info(f"Saved to {joiner_filename}")
 
 
-def export_encoder_model_onnx(
-    encoder_model: nn.Module,
-    encoder_filename: str,
-    opset_version: int = 11,
-) -> None:
-    """Export the given encoder model to ONNX format.
-    The exported model has two inputs:
-
-        - x, a tensor of shape (N, T, C); dtype is torch.float32
-        - x_lens, a tensor of shape (N,); dtype is torch.int64
-
-    and it has two outputs:
-
-        - encoder_out, a tensor of shape (N, T, C)
-        - encoder_out_lens, a tensor of shape (N,)
-
-    Note: The warmup argument is fixed to 1.
-
-    Args:
-      encoder_model:
-        The input encoder model
-      encoder_filename:
-        The filename to save the exported ONNX model.
-      opset_version:
-        The opset version to use.
-    """
-    x = torch.zeros(1, 100, 80, dtype=torch.float32)
-    x_lens = torch.tensor([100], dtype=torch.int64)
-    states = encoder_model.get_init_states()
-    hidden_states = states[0].unsqueeze(1)
-    cell_states = states[1].unsqueeze(1)
-    # encoder_model = torch.jit.script(encoder_model)
-    # It throws the following error for the above statement
-    #
-    # RuntimeError: Exporting the operator __is_ to ONNX opset version
-    # 11 is not supported. Please feel free to request support or
-    # submit a pull request on PyTorch GitHub.
-    #
-    # I cannot find which statement causes the above error.
-    # torch.onnx.export() will use torch.jit.trace() internally, which
-    # works well for the current reworked model
-    warmup = 1.0
-    torch.onnx.export(
-        encoder_model,
-        (x, x_lens, (hidden_states, cell_states), warmup),
-        encoder_filename,
-        verbose=False,
-        opset_version=opset_version,
-        input_names=["x", "x_lens", "hidden_states", "cell_states", "warmup"],
-        output_names=[
-            "encoder_out",
-            "encoder_out_lens",
-            "new_hidden_states",
-            "new_cell_states",
-        ],
-        dynamic_axes={
-            "x": {0: "N", 1: "T"},
-            "x_lens": {0: "N"},
-            "hidden_states": {1: "N"},
-            "cell_states": {1: "N"},
-            "encoder_out": {0: "N", 1: "T"},
-            "encoder_out_lens": {0: "N"},
-            "new_hidden_states": {1: "N"},
-            "new_cell_states": {1: "N"},
-        },
-    )
-    logging.info(f"Saved to {encoder_filename}")
-
-
-def export_decoder_model_onnx(
-    decoder_model: nn.Module,
-    decoder_filename: str,
-    opset_version: int = 11,
-) -> None:
-    """Export the decoder model to ONNX format.
-
-    The exported model has one input:
-
-        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)
-
-    and has one output:
-
-        - decoder_out: a torch.float32 tensor of shape (N, 1, C)
-
-    Note: The argument need_pad is fixed to False.
-
-    Args:
-      decoder_model:
-        The decoder model to be exported.
-      decoder_filename:
-        Filename to save the exported ONNX model.
-      opset_version:
-        The opset version to use.
-    """
-    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
-    need_pad = False  # Always False, so we can use torch.jit.trace() here
-    # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
-    # in this case
-    torch.onnx.export(
-        decoder_model,
-        (y, need_pad),
-        decoder_filename,
-        verbose=False,
-        opset_version=opset_version,
-        input_names=["y", "need_pad"],
-        output_names=["decoder_out"],
-        dynamic_axes={
-            "y": {0: "N"},
-            "decoder_out": {0: "N"},
-        },
-    )
-    logging.info(f"Saved to {decoder_filename}")
-
-
-def export_joiner_model_onnx(
-    joiner_model: nn.Module,
-    joiner_filename: str,
-    opset_version: int = 11,
-) -> None:
-    """Export the joiner model to ONNX format.
-    The exported model has two inputs:
-
-        - encoder_out: a tensor of shape (N, encoder_out_dim)
-        - decoder_out: a tensor of shape (N, decoder_out_dim)
-
-    and has one output:
-
-        - joiner_out: a tensor of shape (N, vocab_size)
-
-    Note: The argument project_input is fixed to True. A user should not
-    project the encoder_out/decoder_out by himself/herself. The exported joiner
-    will do that for the user.
-    """
-    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
-    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
-    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
-    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
-
-    project_input = True
-    # Note: It uses torch.jit.trace() internally
-    torch.onnx.export(
-        joiner_model,
-        (encoder_out, decoder_out, project_input),
-        joiner_filename,
-        verbose=False,
-        opset_version=opset_version,
-        input_names=["encoder_out", "decoder_out", "project_input"],
-        output_names=["logit"],
-        dynamic_axes={
-            "encoder_out": {0: "N"},
-            "decoder_out": {0: "N"},
-            "logit": {0: "N"},
-        },
-    )
-    logging.info(f"Saved to {joiner_filename}")
-
-
-def export_all_in_one_onnx(
-    encoder_filename: str,
-    decoder_filename: str,
-    joiner_filename: str,
-    all_in_one_filename: str,
-):
-    encoder_onnx = onnx.load(encoder_filename)
-    decoder_onnx = onnx.load(decoder_filename)
-    joiner_onnx = onnx.load(joiner_filename)
-
-    encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
-    decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
-    joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
-
-    combined_model = onnx.compose.merge_models(
-        encoder_onnx, decoder_onnx, io_map={}
-    )
-    combined_model = onnx.compose.merge_models(
-        combined_model, joiner_onnx, io_map={}
-    )
-    onnx.save(combined_model, all_in_one_filename)
-    logging.info(f"Saved to {all_in_one_filename}")
-
-
 @torch.no_grad()
 def main():
     args = get_parser().parse_args()
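For reference, the removed exporters fixed the graph input/output names and dynamic axes shown above. A hedged onnxruntime sketch of driving such an encoder (sizes as in the test configuration later in this diff; whether the traced scalar warmup survives as a runtime input depends on how tracing folded it):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("encoder.onnx")
    inputs = {
        "x": np.zeros((1, 100, 80), dtype=np.float32),
        "x_lens": np.array([100], dtype=np.int64),
        "hidden_states": np.zeros((12, 1, 512), dtype=np.float32),
        "cell_states": np.zeros((12, 1, 1024), dtype=np.float32),
    }
    # Outputs follow the output_names declared by the exporter above.
    encoder_out, encoder_out_lens, h, c = sess.run(None, inputs)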
@@ -483,77 +279,88 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    model.to(device)
-
-    if params.iter > 0:
-        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-            : params.avg
-        ]
-        if len(filenames) == 0:
-            raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        elif len(filenames) < params.avg:
-            raise ValueError(
-                f"Not enough checkpoints ({len(filenames)}) found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(
-            average_checkpoints(filenames, device=device), strict=False
-        )
-    elif params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(
-            average_checkpoints(filenames, device=device), strict=False
-        )
+        if params.iter > 0:
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
 
     model.to("cpu")
     model.eval()
     convert_scaled_to_non_scaled(model, inplace=True)
 
-    if params.onnx is True:
-        opset_version = 11
-        logging.info("Exporting to onnx format")
-        encoder_filename = params.exp_dir / "encoder.onnx"
-        export_encoder_model_onnx(
-            model.encoder,
-            encoder_filename,
-            opset_version=opset_version,
-        )
-
-        decoder_filename = params.exp_dir / "decoder.onnx"
-        export_decoder_model_onnx(
-            model.decoder,
-            decoder_filename,
-            opset_version=opset_version,
-        )
-
-        joiner_filename = params.exp_dir / "joiner.onnx"
-        export_joiner_model_onnx(
-            model.joiner,
-            joiner_filename,
-            opset_version=opset_version,
-        )
-
-        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
-        export_all_in_one_onnx(
-            encoder_filename,
-            decoder_filename,
-            joiner_filename,
-            all_in_one_filename,
-        )
-    elif params.jit_trace is True:
+    if params.jit_trace is True:
         logging.info("Using torch.jit.trace()")
         encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
         export_encoder_model_jit_trace(model.encoder, encoder_filename)
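With the default --use-averaged-model=True, export.py now loads the averaged model before exporting; a usage sketch in the style of the script's docstring (the averaged range is (epoch - avg, epoch], so start = epoch - avg must be >= 1 per the assertion above):

    ./lstm_transducer_stateless/export.py \
      --exp-dir ./lstm_transducer_stateless/exp \
      --bpe-model data/lang_bpe_500/bpe.model \
      --epoch 20 \
      --avg 10 \
      --use-averaged-model 1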
@@ -289,9 +289,12 @@ def main():
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
+    states = encoder.get_init_states(batch_size=features.size(0), device=device)
+
     encoder_out, encoder_out_lens, _ = encoder(
         x=features,
         x_lens=feature_lengths,
+        states=states,
     )
 
     hyps = greedy_search(
@@ -179,16 +179,18 @@ class RNN(EncoderInterface):
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         return x, lengths, new_states
 
     @torch.jit.export
     def get_init_states(
-        self, device: torch.device = torch.device("cpu")
+        self, batch_size: int = 1, device: torch.device = torch.device("cpu")
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get model initial states."""
         # for rnn hidden states
         hidden_states = torch.zeros(
-            (self.num_encoder_layers, self.d_model), device=device
+            (self.num_encoder_layers, batch_size, self.d_model), device=device
         )
         cell_states = torch.zeros(
-            (self.num_encoder_layers, self.rnn_hidden_size), device=device
+            (self.num_encoder_layers, batch_size, self.rnn_hidden_size),
+            device=device,
         )
         return (hidden_states, cell_states)
+
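The initial states now carry the batch axis at dim 1, i.e. (num_encoder_layers, N, C); this is consistent with the {1: "N"} dynamic axes the ONNX exporter declared for the state tensors. A quick sketch, assuming `encoder` is an RNN instance with the test configuration sizes:

    states = encoder.get_init_states(batch_size=4, device=torch.device("cpu"))
    hidden_states, cell_states = states  # (12, 4, 512) and (12, 4, 1024), sizes assumed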
@@ -235,7 +237,7 @@ class RNNEncoderLayer(nn.Module):
             ScaledLinear(d_model, dim_feedforward),
             ActivationBalancer(channel_dim=-1),
             DoubleSwish(),
-            nn.Dropout(),
+            nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
         self.norm_final = BasicNorm(d_model)
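This is a real fix, not a refactor: nn.Dropout() uses PyTorch's default p=0.5, so the feed-forward module previously dropped at 50% regardless of the configured rate. For illustration:

    import torch.nn as nn

    dropout = 0.1  # the rate handed to RNNEncoderLayer
    nn.Dropout()  # p defaults to 0.5; the configured rate was ignored
    nn.Dropout(dropout)  # uses the configured rate, as in the fix above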
@@ -763,9 +765,9 @@ if __name__ == "__main__":
     m = RNN(
         num_features=feature_dim,
         d_model=512,
-        rnn_hidden_size=1536,
+        rnn_hidden_size=1024,
         dim_feedforward=2048,
-        num_encoder_layers=10,
+        num_encoder_layers=12,
     )
     batch_size = 5
     seq_len = 20
@@ -116,7 +116,7 @@ class Transducer(nn.Module):
 
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup)
         assert torch.all(x_lens > 0)
 
         # Now for the decoder, i.e., the prediction network
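Taken together with the lstm.py changes, the encoder forward now returns a 3-tuple (encoder_out, encoder_out_lens, new_states). A hedged sketch of the two calling conventions visible in this commit, assuming `encoder` is the RNN above and `x`, `x_lens` a feature batch:

    # Non-streaming callers (as in the hunk above) drop the states:
    encoder_out, x_lens, _ = encoder(x, x_lens, warmup=warmup)

    # Streaming callers seed the states and thread them through chunks:
    states = encoder.get_init_states(batch_size=x.size(0), device=x.device)
    encoder_out, x_lens, states = encoder(x, x_lens, states=states)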