From 1a9e8a5718eb0134f0355b7ca3296a16976e7dd3 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 20 Mar 2022 19:42:20 +0800
Subject: [PATCH] WIP: Periodically save checkpoints after processing a given
 number of batches.

---
 icefall/checkpoint.py | 99 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 99 insertions(+)

diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index dbe3c1315..b4c93f03a 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -15,7 +15,10 @@
 # limitations under the License.
 
 
+import glob
 import logging
+import os
+import re
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
@@ -151,3 +154,99 @@ def average_checkpoints(
             avg[k] //= n
 
     return avg
+
+
+def save_checkpoint_with_global_batch_idx(
+    out_dir: Path,
+    global_batch_idx: int,
+    model: Union[nn.Module, DDP],
+    params: Optional[Dict[str, Any]] = None,
+    optimizer: Optional[Optimizer] = None,
+    scheduler: Optional[_LRScheduler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+):
+    """Save the training state after processing a given number of batches.
+
+    Args:
+      out_dir:
+        The directory to save the checkpoint.
+      global_batch_idx:
+        The number of batches processed so far from the very start of the
+        training. The saved checkpoint will have the following filename:
+
+            out_dir / f'checkpoint-{global_batch_idx}.pt'
+      model:
+        The neural network model whose `state_dict` will be saved in the
+        checkpoint.
+      params:
+        A dict of training configurations to be saved.
+      optimizer:
+        The optimizer used in the training. Its `state_dict` will be saved.
+      scheduler:
+        The learning rate scheduler used in the training. Its `state_dict`
+        will be saved.
+      scaler:
+        The GradScaler used for mixed precision training. Its `state_dict`
+        will be saved.
+      rank:
+        The rank of the current node in DDP training. Set it to 0
+        if DDP is not used.
+    """
+    out_dir = Path(out_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
+    save_checkpoint(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+
+def remove_checkpoints(
+    out_dir: Path,
+    topk: int,
+    rank: int = 0,
+):
+    """Remove checkpoints from the given directory.
+
+    We assume that checkpoint filenames have the form `checkpoint-xxx.pt`,
+    where xxx is a number representing the number of batches processed
+    when that checkpoint was saved. Checkpoints are sorted by xxx in
+    decreasing order and only the `topk` ones with the largest xxx are kept.
+
+    Args:
+      out_dir:
+        The directory containing checkpoints to be removed.
+      topk:
+        Number of checkpoints to keep.
+      rank:
+        If using DDP for training, it is the rank of the current node.
+        Use 0 if no DDP is used for training.
+    """
+    assert topk >= 1, topk
+    if rank != 0:
+        return
+    checkpoints = glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")
+
+    if len(checkpoints) == 0:
+        logging.warning(f"No checkpoints found in {out_dir}")
+        return
+
+    if len(checkpoints) <= topk:
+        return
+
+    pattern = re.compile(r"checkpoint-([0-9]+)\.pt")
+    idx_checkpoints = [
+        (int(pattern.search(c).group(1)), c) for c in checkpoints
+    ]
+
+    idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
+    to_remove = idx_checkpoints[topk:]
+    to_remove = [ic[1] for ic in to_remove]
+    for c in to_remove:
+        os.remove(c)
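
Usage note (not part of the patch): the sketch below shows one way a training
loop might call these helpers, saving a checkpoint every `save_every_n` batches
and pruning all but the newest `keep_last_k` of them. The names `save_every_n`
and `keep_last_k`, the dummy linear model, and the plain `range` standing in
for a real dataloader are illustrative assumptions, not part of icefall.

    from pathlib import Path

    import torch
    import torch.nn as nn

    from icefall.checkpoint import (
        remove_checkpoints,
        save_checkpoint_with_global_batch_idx,
    )

    # Illustrative settings; these names are not defined by the patch.
    save_every_n = 2000  # save a checkpoint every 2000 batches
    keep_last_k = 5      # keep only the 5 newest checkpoint-xxx.pt files
    out_dir = Path("exp")

    model = nn.Linear(80, 500)  # stand-in for the real acoustic model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # A real loop would iterate over a train dataloader; a range keeps the
    # sketch self-contained.
    for global_batch_idx in range(1, 10001):
        # ... forward, backward, optimizer.step() ...

        if global_batch_idx % save_every_n == 0:
            save_checkpoint_with_global_batch_idx(
                out_dir=out_dir,
                global_batch_idx=global_batch_idx,
                model=model,
                optimizer=optimizer,
                rank=0,
            )
            # Remove older checkpoint-xxx.pt files, keeping the newest ones.
            remove_checkpoints(out_dir=out_dir, topk=keep_last_k, rank=0)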
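
A second, smaller sketch (also illustrative; the temporary directory and the
empty dicts saved as dummy checkpoints are assumptions made only so the example
is self-contained) shows the pruning order: remove_checkpoints keeps the
numerically largest batch indexes, not the lexicographically largest filenames,
so checkpoint-900.pt is removed even though "900" sorts after "12000" as a
string.

    import tempfile
    from pathlib import Path

    import torch

    from icefall.checkpoint import remove_checkpoints

    with tempfile.TemporaryDirectory() as d:
        out_dir = Path(d)
        # remove_checkpoints only inspects filenames, so empty dicts suffice.
        for idx in [900, 2000, 10000, 12000]:
            torch.save({}, out_dir / f"checkpoint-{idx}.pt")

        remove_checkpoints(out_dir=out_dir, topk=2, rank=0)

        kept = sorted(p.name for p in out_dir.glob("checkpoint-*.pt"))
        print(kept)  # ['checkpoint-10000.pt', 'checkpoint-12000.pt']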