from local

This commit is contained in:
dohe0342 2023-04-18 14:08:23 +09:00
parent 5d65ccfd71
commit 000dc1af5c
3 changed files with 2 additions and 7 deletions

Binary file not shown.

View File

@ -1263,12 +1263,7 @@ def run(rank, world_size, args, wb=None):
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
model_avg = None
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg