from local

This commit is contained in:
dohe0342 2023-01-25 16:51:48 +09:00
parent b77772f4b1
commit 87b37383ab
2 changed files with 12 additions and 11 deletions

View File

@ -1464,17 +1464,18 @@ def run(rank, world_size, args, wb=None):
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if epoch % 10 == 0:
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
logging.info("Done!")