mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-02 21:54:18 +00:00
WIP: Periodically save a checkpoint after processing a given number of batches.
This commit is contained in:
parent ad28c8c5eb
commit 1a9e8a5718
@@ -15,7 +15,10 @@
# limitations under the License.

import glob
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

||||
@@ -151,3 +154,99 @@ def average_checkpoints(
        avg[k] //= n

    return avg


def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[_LRScheduler] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
):
    """Save training state after processing a given number of batches.

    Args:
      out_dir:
        The directory in which to save the checkpoint.
      global_batch_idx:
        The number of batches processed so far from the very start of the
        training. The saved checkpoint will have the following filename:

            f"{out_dir}/checkpoint-{global_batch_idx}.pt"
      model:
        The neural network model whose `state_dict` will be saved in the
        checkpoint.
      params:
        A dict of training configurations to be saved.
      optimizer:
        The optimizer used in the training. Its `state_dict` will be saved.
      scheduler:
        The learning rate scheduler used in the training. Its `state_dict`
        will be saved.
      scaler:
        The scaler used for mixed precision training. Its `state_dict` will
        be saved.
      rank:
        The rank of the current node in DDP training. Set it to 0 if DDP
        is not used.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
    save_checkpoint(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        rank=rank,
    )


def remove_checkpoints(
    out_dir: Path,
    topk: int,
    rank: int = 0,
):
    """Remove checkpoints from the given directory.

    We assume that each checkpoint filename has the form `checkpoint-xxx.pt`,
    where `xxx` is a number representing the number of batches processed
    when that checkpoint was saved. Checkpoints are sorted by this batch
    index, and only the `topk` checkpoints with the highest indices are kept.

    Args:
      out_dir:
        The directory containing checkpoints to be removed.
      topk:
        Number of checkpoints to keep.
      rank:
        If using DDP for training, it is the rank of the current node.
        Use 0 if no DDP is used for training.
    """
    assert topk >= 1, topk
    if rank != 0:
        return
    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))

    if len(checkpoints) == 0:
        logging.warning(f"No checkpoints found in {out_dir}")
        return

    if len(checkpoints) <= topk:
        return

    pattern = re.compile(r"checkpoint-([0-9]+)\.pt")
    idx_checkpoints = [
        (int(pattern.search(c).group(1)), c) for c in checkpoints
    ]

    # Sort by batch index in descending order; the `topk` newest survive.
    idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
    to_remove = idx_checkpoints[topk:]
    to_remove = [ic[1] for ic in to_remove]
    for c in to_remove:
        os.remove(c)
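
A minimal sketch of how the two helpers added in this commit might be wired into a training loop. The save interval, directory name, and the surrounding training objects (model, optimizer, scheduler, scaler, train_dl, params, num_epochs, rank) are hypothetical placeholders, not part of this commit:

# Hypothetical usage: save a checkpoint every `save_every_n` batches
# and keep only the `keep_last_k` most recent ones. Both helpers no-op
# on non-zero ranks, so they can be called unconditionally.
out_dir = Path("exp")
save_every_n = 1000
keep_last_k = 5

global_batch_idx = 0
for epoch in range(num_epochs):
    for batch in train_dl:
        global_batch_idx += 1
        # ... forward pass, backward pass, optimizer.step() ...

        if global_batch_idx % save_every_n == 0:
            save_checkpoint_with_global_batch_idx(
                out_dir=out_dir,
                global_batch_idx=global_batch_idx,
                model=model,
                params=params,
                optimizer=optimizer,
                scheduler=scheduler,
                scaler=scaler,
                rank=rank,
            )
            remove_checkpoints(out_dir=out_dir, topk=keep_last_k, rank=rank)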
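
One detail worth noting in remove_checkpoints: the batch index is parsed out of the filename and compared as an integer, because sorting the filenames as strings would order them incorrectly once the index gains a digit. A small standalone illustration (the filenames here are made up):

import re

names = ["checkpoint-8000.pt", "checkpoint-10000.pt", "checkpoint-9000.pt"]

# Lexicographic order is wrong: "8..." sorts above "10...".
print(sorted(names, reverse=True))
# -> ['checkpoint-9000.pt', 'checkpoint-8000.pt', 'checkpoint-10000.pt']

# Numeric order, as done in remove_checkpoints, is correct.
pattern = re.compile(r"checkpoint-([0-9]+)\.pt")
print(sorted(names, key=lambda c: int(pattern.search(c).group(1)), reverse=True))
# -> ['checkpoint-10000.pt', 'checkpoint-9000.pt', 'checkpoint-8000.pt']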