Minor fixes.

This commit is contained in:
Fangjun Kuang 2021-10-04 08:20:40 +08:00
parent d1a4267a69
commit 539d656606

View File

@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
# This script converts several saved checkpoints # This script converts several saved checkpoints
# to one using model averaging. # to a single one using model averaging.
import argparse import argparse
import logging import logging
@ -27,8 +27,8 @@ import torch
from conformer import Conformer from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool, AttributeDict
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
def get_parser(): def get_parser():
@ -74,7 +74,7 @@ def get_parser():
"--jit", "--jit",
type=str2bool, type=str2bool,
default=True, default=True,
help="""True to save a model after using torch.jit.script. help="""True to save a model after applying torch.jit.script.
""", """,
) )
@ -84,8 +84,6 @@ def get_parser():
def get_params() -> AttributeDict: def get_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"a": 1,
"b": 10,
"feature_dim": 80, "feature_dim": 80,
"subsampling_factor": 4, "subsampling_factor": 4,
"use_feat_batchnorm": True, "use_feat_batchnorm": True,
@ -127,6 +125,7 @@ def main():
vgg_frontend=False, vgg_frontend=False,
use_feat_batchnorm=params.use_feat_batchnorm, use_feat_batchnorm=params.use_feat_batchnorm,
) )
model.to(device)
if params.avg == 1: if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@ -144,12 +143,16 @@ def main():
if params.jit: if params.jit:
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
model.save(f"{params.exp_dir}/cpu_jit.pt") filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else: else:
logging.info("Not using torch.jit.script") logging.info("Not using torch.jit.script")
torch.save( # Save it using a format so that it can be loaded
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" # by :func:`load_checkpoint`
) filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__": if __name__ == "__main__":