mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

Minor fixes.

This commit is contained in:
parent d1a4267a69
commit 539d656606
@@ -17,7 +17,7 @@
 # limitations under the License.
 
 # This script converts several saved checkpoints
-# to one using model averaging.
+# to a single one using model averaging.
 
 import argparse
 import logging
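Note: "model averaging" here means taking the element-wise mean of the parameters from several epoch checkpoints, which often performs better than any single epoch. A minimal sketch of the idea in plain PyTorch (the real `average_checkpoints` imported below may differ; this body is an illustrative assumption, not icefall's implementation):

    import torch

    def average_checkpoints_sketch(filenames):
        # Element-wise mean of the "model" state_dicts across checkpoints.
        # Illustrative assumption only; icefall's helper may differ.
        n = len(filenames)
        avg = torch.load(filenames[0], map_location="cpu")["model"]
        for f in filenames[1:]:
            state = torch.load(f, map_location="cpu")["model"]
            for k in avg:
                avg[k] += state[k]
        for k in avg:
            if avg[k].is_floating_point():
                avg[k] /= n
            else:
                avg[k] //= n
        return avg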
@@ -27,8 +27,8 @@ import torch
 from conformer import Conformer
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.utils import str2bool, AttributeDict
 from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, str2bool
 
 
 def get_parser():
@@ -74,7 +74,7 @@ def get_parser():
         "--jit",
         type=str2bool,
         default=True,
-        help="""True to save a model after using torch.jit.script.
+        help="""True to save a model after applying torch.jit.script.
         """,
     )
 
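Note: `str2bool` is used as the argparse `type` so the flag accepts spellings like `--jit=0` or `--jit=false`; plain `bool` would treat any non-empty string as true. icefall ships its own converter in `icefall.utils`; a typical implementation looks roughly like this (an assumption, not the verified source):

    import argparse

    def str2bool(v):
        # Map common true/false spellings onto a real bool;
        # anything else is rejected as an argument error.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")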
@@ -84,8 +84,6 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "a": 1,
-            "b": 10,
             "feature_dim": 80,
             "subsampling_factor": 4,
             "use_feat_batchnorm": True,
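Note: `AttributeDict` lets the rest of the script read these entries as attributes (`params.feature_dim`) instead of keys (`params["feature_dim"]`), which is how `params.jit`, `params.avg`, and `params.exp_dir` are used below. A sketch of the pattern, assuming a simple dict subclass rather than icefall's actual class:

    class AttributeDict(dict):
        # Dict whose keys can also be read and written as attributes.
        def __getattr__(self, key):
            try:
                return self[key]
            except KeyError:
                raise AttributeError(key)

        def __setattr__(self, key, value):
            self[key] = value

    params = AttributeDict({"feature_dim": 80, "subsampling_factor": 4})
    assert params.feature_dim == 80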
@@ -127,6 +125,7 @@ def main():
         vgg_frontend=False,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
+    model.to(device)
 
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
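Note: the `avg > 1` branch is cut off by the hunk boundary; in the usual icefall pattern it averages the last `avg` epoch checkpoints and loads the result into the model. A hedged sketch of that shape (the file-naming scheme and range are assumptions based on the `avg == 1` line above):

    from icefall.checkpoint import average_checkpoints

    def load_averaged(model, params):
        # Assumed shape of the truncated branch: average the last
        # `params.avg` epoch checkpoints, then load the result.
        start = params.epoch - params.avg + 1
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(start, params.epoch + 1)
        ]
        model.load_state_dict(average_checkpoints(filenames))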
@@ -144,12 +143,16 @@ def main():
     if params.jit:
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        model.save(f"{params.exp_dir}/cpu_jit.pt")
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
     else:
         logging.info("Not using torch.jit.script")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
 
 
 if __name__ == "__main__":
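Note: both export formats load back with stock PyTorch calls: `torch.jit.load` for the scripted model (no Python class definition needed) and `torch.load` plus `load_state_dict` for the checkpoint-style file. A minimal sketch, with the paths and the model constructor assumed:

    import torch

    # Scripted export: self-contained, runs without the Conformer source.
    jit_model = torch.jit.load("exp/cpu_jit.pt")

    # Checkpoint-style export: rebuild the model, then restore weights.
    # `build_conformer()` is a hypothetical stand-in for constructing
    # the same Conformer as in this script.
    model = build_conformer()
    state = torch.load("exp/pretrained.pt", map_location="cpu")
    model.load_state_dict(state["model"])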