assert params.start_epoch>0

This commit is contained in:
yaozengwei 2022-05-07 10:32:57 +08:00
parent b1e9d2186d
commit 0bee5d058a

View File

@ -125,8 +125,8 @@ def get_parser():
"--start-epoch", "--start-epoch",
type=int, type=int,
default=1, default=1,
help="""Resume training from from this epoch. help="""Resume training from this epoch. It should be positive.
If it is positive, it will load checkpoint from If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt exp-dir/epoch-{start_epoch-1}.pt
""", """,
) )
@ -861,6 +861,7 @@ def run(rank, world_size, args):
# model_avg is only used with rank 0 # model_avg is only used with rank 0
model_avg = copy.deepcopy(model) model_avg = copy.deepcopy(model)
assert params.start_epoch > 0
checkpoints = load_checkpoint_if_available( checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg params=params, model=model, model_avg=model_avg
) )