From 539d656606698b8c8df48767bafb832126af42d7 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 4 Oct 2021 08:20:40 +0800
Subject: [PATCH] Minor fixes.

---
 egs/librispeech/ASR/conformer_ctc/export.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py
index 1b258f01e..b71385417 100755
--- a/egs/librispeech/ASR/conformer_ctc/export.py
+++ b/egs/librispeech/ASR/conformer_ctc/export.py
@@ -17,7 +17,7 @@
 # limitations under the License.
 
 # This script converts several saved checkpoints
-# to one using model averaging.
+# to a single one using model averaging.
 
 import argparse
 import logging
@@ -27,8 +27,8 @@
 import torch
 from conformer import Conformer
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.utils import str2bool, AttributeDict
 from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, str2bool
 
 
 def get_parser():
@@ -74,7 +74,7 @@
         "--jit",
         type=str2bool,
         default=True,
-        help="""True to save a model after using torch.jit.script.
+        help="""True to save a model after applying torch.jit.script.
         """,
     )
 
@@ -84,8 +84,6 @@
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "a": 1,
-            "b": 10,
             "feature_dim": 80,
             "subsampling_factor": 4,
             "use_feat_batchnorm": True,
@@ -127,6 +125,7 @@
         vgg_frontend=False,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
+    model.to(device)
 
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -144,12 +143,16 @@
     if params.jit:
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        model.save(f"{params.exp_dir}/cpu_jit.pt")
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
     else:
         logging.info("Not using torch.jit.script")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
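
For reference, the two branches in the last hunk save files that are loaded in
different ways: cpu_jit.pt is a TorchScript archive, while pretrained.pt is a
plain checkpoint dict. A minimal loading sketch, not part of the patch itself,
assuming a hypothetical exp_dir that contains the exported files (the
commented-out Conformer line stands in for rebuilding the model with the same
hyperparameters used at export time):

    from pathlib import Path

    import torch

    exp_dir = Path("conformer_ctc/exp")  # hypothetical export directory

    # cpu_jit.pt was written by torch.jit.script(model).save(...), so it is
    # restored with torch.jit.load and does not need the Conformer class.
    jit_model = torch.jit.load(str(exp_dir / "cpu_jit.pt"))
    jit_model.eval()

    # pretrained.pt holds {"model": state_dict}, the layout expected by
    # icefall's load_checkpoint; with plain PyTorch it can be read back as:
    checkpoint = torch.load(exp_dir / "pretrained.pt", map_location="cpu")
    # model = Conformer(...)  # must match the hyperparameters used above
    # model.load_state_dict(checkpoint["model"])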