WIP: Periodically save checkpoints after processing a given number of batches.

Fangjun Kuang 2022-03-20 19:42:20 +08:00
parent ad28c8c5eb
commit 1a9e8a5718


@@ -15,7 +15,10 @@
# limitations under the License.
import glob
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@@ -151,3 +154,99 @@ def average_checkpoints(
        avg[k] //= n

    return avg

def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[_LRScheduler] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
):
"""Save training info after processing given number of batches.
Args:
out_dir:
The directory to save the checkpoint.
global_batch_idx:
The number of batches processed so far from the very start of the
training. The saved checkpoint will have the following filename:
f'out_dir / checkpoint-{global_batch_idx}.pt'
model:
The neural network model whose `state_dict` will be saved in the
checkpoint.
params:
A dict of training configurations to be saved.
optimizer:
The optimizer used in the training. Its `state_dict` will be saved.
scheduler:
The learning rate scheduler used in the training. Its `state_dict` will
be saved.
scaler:
The scaler used for mix precision training. Its `state_dict` will
be saved.
rank:
The rank ID used in DDP training of the current node. Set it to 0
if DDP is not used.
"""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
    save_checkpoint(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        rank=rank,
    )
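
For context, a minimal usage sketch of how this helper might be driven from a training loop. It is not part of this commit; `train_dl`, `compute_loss`, `exp_dir`, and `save_every_n` are hypothetical names used only for illustration.

# Hedged sketch only: train_dl, compute_loss, exp_dir and save_every_n
# below are illustrative and not defined in this commit.
save_every_n = 1000  # save a checkpoint every 1000 batches (hypothetical)
global_batch_idx = 0

for batch in train_dl:
    loss = compute_loss(model, batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    global_batch_idx += 1

    if global_batch_idx % save_every_n == 0:
        # Writes exp_dir / f"checkpoint-{global_batch_idx}.pt"
        save_checkpoint_with_global_batch_idx(
            out_dir=exp_dir,
            global_batch_idx=global_batch_idx,
            model=model,
            optimizer=optimizer,
            rank=0,
        )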

def remove_checkpoints(
    out_dir: Path,
    topk: int,
    rank: int = 0,
):
"""Remove checkpoints from the given directory.
We assume that checkpoint filename has the form `checkpoint-xxx.pt`
where xxx is a number, representing the number of processed batches
when saving that checkpoint. We sort checkpoints by filename and keep
only the `topk` checkpoints with the highest `xxx`.
Args:
out_dir:
The directory containing checkpoints to be removed.
topk:
Number of checkpoints to keep.
rank:
If using DDP for training, it is the rank of the current node.
Use 0 if no DDP is used for training.
"""
    assert topk >= 1, topk
    if rank != 0:
        return
    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
    if len(checkpoints) == 0:
        logging.warning(f"No checkpoints found in {out_dir}")
        return
    if len(checkpoints) <= topk:
        return

    pattern = re.compile(r"checkpoint-([0-9]+).pt")
    idx_checkpoints = [
        (int(pattern.search(c).group(1)), c) for c in checkpoints
    ]

    idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
    to_remove = idx_checkpoints[topk:]
    to_remove = [ic[1] for ic in to_remove]
    for c in to_remove:
        os.remove(c)
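
Taken together, the two helpers can implement a rolling-checkpoint policy: save periodically, then prune so that only the newest few checkpoints survive. A hedged sketch, assuming hypothetical `exp_dir`, `global_batch_idx`, and `keep_last_k` values:

# Hedged sketch: exp_dir, global_batch_idx and keep_last_k are
# illustrative, not part of this commit.
keep_last_k = 5

save_checkpoint_with_global_batch_idx(
    out_dir=exp_dir,
    global_batch_idx=global_batch_idx,
    model=model,
    optimizer=optimizer,
    rank=rank,
)
# Keep only the keep_last_k checkpoints with the largest batch index.
remove_checkpoints(out_dir=exp_dir, topk=keep_last_k, rank=rank)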