add export.py

Yifan Yang 2023-06-02 11:19:50 +08:00
parent afc9f9c413
commit 90e758b908
3 changed files with 122 additions and 56 deletions

View File

@@ -112,7 +112,6 @@ def main():
     for torch_v, onnx_v in zip(
         (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
     ):
         assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
             torch_v.shape,
             onnx_v.shape,
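
The assertion in this hunk is the core of a Torch-vs-ONNX parity check. As a rough sketch of the pattern (the helper name, the session setup, and the single-output assumption below are illustrative, not taken from this commit):

import onnxruntime as ort
import torch

def check_onnx_matches_torch(torch_model, onnx_path: str, x: torch.Tensor):
    # Run the same input through the original torch model...
    torch_out = torch_model(x)
    # ...and through the exported graph via onnxruntime.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (onnx_out,) = session.run(None, {session.get_inputs()[0].name: x.numpy()})
    # The export is considered faithful if outputs agree within tolerance.
    assert torch.allclose(torch_out, torch.from_numpy(onnx_out), atol=1e-5), (
        torch_out.shape,
        onnx_out.shape,
    )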

View File

@@ -5,16 +5,16 @@
 import argparse
 import logging
 from pathlib import Path
+from typing import Dict

 import onnx
 import torch
 from model import RnnLmModel
 from onnxruntime.quantization import QuantType, quantize_dynamic
+from train import get_params

 from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import AttributeDict, str2bool
-from typing import Dict
-from train import get_params


 def add_meta_data(filename: str, meta_data: Dict[str, str]):
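
For context on the helper whose signature closes this hunk: a typical add_meta_data implementation attaches key/value pairs to the ONNX model's metadata_props and rewrites the file in place. A sketch of that common pattern (the body below is an assumption based on standard onnx usage, not copied from this commit):

import onnx
from typing import Dict

def add_meta_data(filename: str, meta_data: Dict[str, str]):
    # Load the model, append each pair as a StringStringEntryProto,
    # and save the result back to the same path.
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value
    onnx.save(model, filename)

The QuantType and quantize_dynamic imports suggest this script also emits an int8-quantized copy of the exported model; in onnxruntime that is typically a single call of the form quantize_dynamic(model_input="model.onnx", model_output="model.int8.onnx", weight_type=QuantType.QInt8).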

View File

@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 #
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Yifan Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -25,7 +27,12 @@ from pathlib import Path
 import torch
 from model import RnnLmModel

-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import AttributeDict, str2bool
@@ -37,18 +44,10 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
-    )
-
-    parser.add_argument(
-        "--avg",
-        type=int,
-        default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )

     parser.add_argument(
@@ -61,6 +60,35 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        " Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="rnnlm/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
     parser.add_argument(
         "--vocab-size",
         type=int,
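
The --use-averaged-model help above is terse, so one clarifying note: in icefall each checkpoint stores a running average of the model weights seen so far, which is why average_checkpoints_with_averaged_model only needs the two endpoint checkpoints to average over the range from `epoch-avg` (excluded) to `epoch`. A scalar sketch of the identity involved (the function name and scalar types are mine; the real code applies this per tensor in the state dict):

def interval_average(avg_start: float, start: int, avg_end: float, end: int) -> float:
    # avg_k is the running mean over steps 1..k, so the sum over
    # steps 1..k is avg_k * k. Differencing the two sums and dividing
    # by the interval length yields the mean over (start, end].
    return (avg_end * end - avg_start * start) / (end - start)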
@@ -98,20 +126,14 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--exp-dir",
-        type=str,
-        default="rnn_lm/exp",
-        help="""It specifies the directory where all training related
-        files, e.g., checkpoints, log, etc, are saved
-        """,
-    )
-
     parser.add_argument(
         "--jit",
         type=str2bool,
-        default=True,
+        default=False,
         help="""True to save a model after applying torch.jit.script.
+        It will generate a file named cpu_jit.pt
+        Check ./jit_pretrained.py for how to use it.
         """,
     )
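
The new --jit help mentions cpu_jit.pt and jit_pretrained.py. For readers new to TorchScript, a minimal, self-contained sketch of the export-and-reload round trip (TinyLM is a hypothetical stand-in for RnnLmModel):

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # Illustrative placeholder; the real model is RnnLmModel.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

model = TinyLM().eval()
torch.jit.script(model).save("cpu_jit.pt")

# The scripted file can later be reloaded without the Python class:
loaded = torch.jit.load("cpu_jit.pt")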
@@ -144,13 +166,15 @@ def main():
     model.to(device)

+    if not params.use_averaged_model:
         if params.iter > 0:
             filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                 : params.avg
             ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -159,21 +183,64 @@ def main():
                 )
             logging.info(f"averaging {filenames}")
             model.to(device)
-            model.load_state_dict(
-                average_checkpoints(filenames, device=device), strict=False
-            )
+            model.load_state_dict(average_checkpoints(filenames, device=device))
         elif params.avg == 1:
             load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
         else:
             start = params.epoch - params.avg + 1
             filenames = []
             for i in range(start, params.epoch + 1):
-                if i >= 0:
+                if i >= 1:
                     filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
             logging.info(f"averaging {filenames}")
             model.to(device)
-            model.load_state_dict(
-                average_checkpoints(filenames, device=device), strict=False
-            )
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )

     model.to("cpu")
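
Putting the new options together, a typical invocation of the updated export script might look like the following (the script path and experiment directory are assumptions based on the defaults above; adjust to the actual recipe layout):

./rnn_lm/export.py \
  --exp-dir rnn_lm/exp \
  --epoch 30 \
  --avg 9 \
  --use-averaged-model True \
  --jit True

With --use-averaged-model True and these values, the script loads epoch-21.pt and epoch-30.pt and computes the averaged model over epochs 22 through 30; with --jit True it additionally writes the TorchScript file cpu_jit.pt.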