Minor fixes.

This commit is contained in:
Fangjun Kuang 2021-10-04 08:20:40 +08:00
parent d1a4267a69
commit 539d656606

View File

@ -17,7 +17,7 @@
# limitations under the License.
# This script converts several saved checkpoints
# to one using model averaging.
# to a single one using model averaging.
import argparse
import logging
@ -27,8 +27,8 @@ import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool, AttributeDict
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
def get_parser():
@ -74,7 +74,7 @@ def get_parser():
"--jit",
type=str2bool,
default=True,
help="""True to save a model after using torch.jit.script.
help="""True to save a model after applying torch.jit.script.
""",
)
@ -84,8 +84,6 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
"a": 1,
"b": 10,
"feature_dim": 80,
"subsampling_factor": 4,
"use_feat_batchnorm": True,
@ -127,6 +125,7 @@ def main():
vgg_frontend=False,
use_feat_batchnorm=params.use_feat_batchnorm,
)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@ -144,12 +143,16 @@ def main():
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(f"{params.exp_dir}/cpu_jit.pt")
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":