Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
Minor fixes.
parent d1a4267a69
commit 539d656606
@@ -17,7 +17,7 @@
 # limitations under the License.
 
 # This script converts several saved checkpoints
-# to one using model averaging.
+# to a single one using model averaging.
 
 import argparse
 import logging
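Note: the averaging mentioned in the comment is a parameter-wise mean over several training checkpoints. A minimal sketch of the idea follows; it is not icefall's actual average_checkpoints implementation, whose signature and details may differ.

from typing import List

import torch


def average_state_dicts(filenames: List[str]) -> dict:
    # Accumulate the "model" state dict of each checkpoint, then
    # divide by the checkpoint count to get the element-wise mean.
    n = len(filenames)
    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked)
            # only support integer division.
            avg[k] //= n
    return avg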
@@ -27,8 +27,8 @@ import torch
 from conformer import Conformer
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.utils import str2bool, AttributeDict
 from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, str2bool
 
 
 def get_parser():
@@ -74,7 +74,7 @@ def get_parser():
         "--jit",
         type=str2bool,
         default=True,
-        help="""True to save a model after using torch.jit.script.
+        help="""True to save a model after applying torch.jit.script.
         """,
     )
 
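Note: type=str2bool lets the flag accept textual booleans such as --jit=false; a plain type=bool would treat any non-empty string as True. A common shape for such a helper is sketched below; icefall's own str2bool may differ in detail.

import argparse


def str2bool(v):
    # Map common textual boolean spellings to bool for argparse.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")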
@@ -84,8 +84,6 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "a": 1,
-            "b": 10,
             "feature_dim": 80,
             "subsampling_factor": 4,
             "use_feat_batchnorm": True,
@@ -127,6 +125,7 @@ def main():
         vgg_frontend=False,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
+    model.to(device)
 
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -144,12 +143,16 @@ def main():
     if params.jit:
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        model.save(f"{params.exp_dir}/cpu_jit.pt")
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
     else:
         logging.info("Not using torch.jit.script")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
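Note: the two artifacts written here are consumed differently. cpu_jit.pt is a TorchScript module that needs no Python model class at load time; pretrained.pt stores only a {"model": state_dict} payload in the layout icefall's load_checkpoint expects. The str(filename) calls presumably accommodate APIs that expect string paths rather than pathlib.Path objects (params.exp_dir is a Path, as the / operator shows). A loading sketch, with the "exp" directory name assumed for illustration:

import torch

# TorchScript export: self-contained, no Python model class needed.
scripted_model = torch.jit.load("exp/cpu_jit.pt")

# state_dict export: construct the model first, then restore weights.
state_dict = torch.load("exp/pretrained.pt", map_location="cpu")["model"]
# model.load_state_dict(state_dict)  # after model = Conformer(...)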