WIP: Periodically save a checkpoint after processing a given number of batches.

Fangjun Kuang 2022-03-20 19:42:20 +08:00
parent ad28c8c5eb
commit 1a9e8a5718


@@ -15,7 +15,10 @@
# limitations under the License.
import glob
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@@ -151,3 +154,99 @@ def average_checkpoints(
        avg[k] //= n
    return avg


def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[_LRScheduler] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
):
    """Save training info after processing a given number of batches.

    Args:
      out_dir:
        The directory in which to save the checkpoint.
      global_batch_idx:
        The number of batches processed so far from the very start of the
        training. The saved checkpoint will have the filename
        f"{out_dir}/checkpoint-{global_batch_idx}.pt".
      model:
        The neural network model whose `state_dict` will be saved in the
        checkpoint.
      params:
        A dict of training configurations to be saved.
      optimizer:
        The optimizer used in the training. Its `state_dict` will be saved.
      scheduler:
        The learning rate scheduler used in the training. Its `state_dict`
        will be saved.
      scaler:
        The scaler used for mixed precision training. Its `state_dict` will
        be saved.
      rank:
        The rank ID of the current node in DDP training. Set it to 0
        if DDP is not used.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
    save_checkpoint(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        rank=rank,
    )
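
A minimal usage sketch (not part of this diff): inside a training loop, the helper could be invoked every fixed number of batches. The loop variables and the `save_every_n` threshold below are illustrative assumptions, not names introduced by this commit.

# Hypothetical training-loop snippet; `train_dl`, `save_every_n`, `exp_dir`
# and the model/optimizer/scheduler/scaler objects are assumed to be
# defined by the caller.
global_batch_idx = 0
for batch in train_dl:
    global_batch_idx += 1
    # ... forward, backward, optimizer.step() ...
    if global_batch_idx % save_every_n == 0:
        save_checkpoint_with_global_batch_idx(
            out_dir=exp_dir,
            global_batch_idx=global_batch_idx,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            rank=rank,
        )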


def remove_checkpoints(
    out_dir: Path,
    topk: int,
    rank: int = 0,
):
    """Remove checkpoints from the given directory.

    We assume that a checkpoint filename has the form `checkpoint-xxx.pt`,
    where xxx is a number representing the number of processed batches
    when that checkpoint was saved. Checkpoints are sorted by that number
    and only the `topk` checkpoints with the largest `xxx` are kept.

    Args:
      out_dir:
        The directory containing checkpoints to be removed.
      topk:
        Number of checkpoints to keep.
      rank:
        If using DDP for training, it is the rank of the current node.
        Use 0 if DDP is not used for training.
    """
    assert topk >= 1, topk
    if rank != 0:
        return
    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
    if len(checkpoints) == 0:
        logging.warning(f"No checkpoints found in {out_dir}")
        return

    if len(checkpoints) <= topk:
        return

    pattern = re.compile(r"checkpoint-([0-9]+)\.pt")
    idx_checkpoints = [
        (int(pattern.search(c).group(1)), c) for c in checkpoints
    ]

    idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
    to_remove = idx_checkpoints[topk:]
    to_remove = [ic[1] for ic in to_remove]
    for c in to_remove:
        os.remove(c)
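
To bound disk usage, the two helpers are presumably meant to be used together: save, then prune. A sketch under that assumption; `keep_last_k` and `exp_dir` are illustrative names, not something defined in this commit.

# After each periodic save, keep only the newest `keep_last_k` checkpoints.
save_checkpoint_with_global_batch_idx(
    out_dir=exp_dir,
    global_batch_idx=global_batch_idx,
    model=model,
    rank=rank,
)
remove_checkpoints(out_dir=exp_dir, topk=keep_last_k, rank=rank)

For example, if exp_dir contains checkpoint-1000.pt, checkpoint-2000.pt and checkpoint-3000.pt, then remove_checkpoints(exp_dir, topk=2) deletes checkpoint-1000.pt and keeps the two checkpoints with the largest batch indices.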