mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-03 06:04:18 +00:00

WIP: Periodically saving a checkpoint after processing a given number of batches.

parent ad28c8c5eb
commit 1a9e8a5718
@@ -15,7 +15,10 @@
# limitations under the License.

import glob
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@@ -151,3 +154,99 @@ def average_checkpoints(
        avg[k] //= n

    return avg


def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[_LRScheduler] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
):
    """Save training info after processing a given number of batches.

    Args:
      out_dir:
        The directory in which to save the checkpoint.
      global_batch_idx:
        The number of batches processed so far from the very start of the
        training. The saved checkpoint will have the following filename:

            f"{out_dir}/checkpoint-{global_batch_idx}.pt"
      model:
        The neural network model whose `state_dict` will be saved in the
        checkpoint.
      params:
        A dict of training configurations to be saved.
      optimizer:
        The optimizer used in the training. Its `state_dict` will be saved.
      scheduler:
        The learning rate scheduler used in the training. Its `state_dict`
        will be saved.
      scaler:
        The scaler used for mixed precision training. Its `state_dict` will
        be saved.
      rank:
        The rank ID of the current node in DDP training. Set it to 0
        if DDP is not used.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
    save_checkpoint(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        rank=rank,
    )
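As a rough illustration (not part of this commit), the new helper could be called as sketched below. The model, optimizer, and `params` values are placeholders, and the import path assumes the function lives in icefall/checkpoint.py alongside the existing save_checkpoint:

import torch
import torch.nn as nn

from icefall.checkpoint import save_checkpoint_with_global_batch_idx

model = nn.Linear(80, 500)  # stand-in for a real acoustic model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Writes exp/checkpoint-1000.pt containing the model and optimizer
# state_dicts plus the `params` dict.
save_checkpoint_with_global_batch_idx(
    out_dir="exp",
    global_batch_idx=1000,  # batches seen since the start of training
    model=model,
    params={"lr": 1e-3},
    optimizer=optimizer,
)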

def remove_checkpoints(
    out_dir: Path,
    topk: int,
    rank: int = 0,
):
    """Remove checkpoints from the given directory.

    We assume that checkpoint filenames have the form `checkpoint-xxx.pt`,
    where xxx is a number representing the number of processed batches
    when that checkpoint was saved. Checkpoints are sorted by this number,
    and only the `topk` checkpoints with the highest `xxx` are kept.

    Args:
      out_dir:
        The directory containing checkpoints to be removed.
      topk:
        Number of checkpoints to keep.
      rank:
        If using DDP for training, it is the rank of the current node.
        Use 0 if no DDP is used for training.
    """
    assert topk >= 1, topk
    if rank != 0:
        return
    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))

    if len(checkpoints) == 0:
        logging.warning(f"No checkpoints found in {out_dir}")
        return

    if len(checkpoints) <= topk:
        return

    pattern = re.compile(r"checkpoint-([0-9]+)\.pt")
    idx_checkpoints = [
        (int(pattern.search(c).group(1)), c) for c in checkpoints
    ]

    idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
    to_remove = idx_checkpoints[topk:]
    to_remove = [ic[1] for ic in to_remove]
    for c in to_remove:
        os.remove(c)
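Together, the two helpers allow periodic checkpointing while capping disk usage. A minimal sketch of how they might be wired into a recipe's training loop follows; `train_dl`, `compute_loss`, and the hyperparameters are hypothetical placeholders, not part of this commit:

import torch
import torch.nn as nn

from icefall.checkpoint import (
    remove_checkpoints,
    save_checkpoint_with_global_batch_idx,
)


def train_with_periodic_checkpoints(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    train_dl,
    out_dir: str = "exp",
    save_every_n: int = 1000,
    topk: int = 3,
    rank: int = 0,
):
    # `compute_loss` and `train_dl` stand in for a recipe's own loss
    # function and dataloader.
    global_batch_idx = 0
    for batch in train_dl:
        loss = compute_loss(model, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        global_batch_idx += 1
        if global_batch_idx % save_every_n == 0:
            # Write out_dir/checkpoint-{global_batch_idx}.pt ...
            save_checkpoint_with_global_batch_idx(
                out_dir=out_dir,
                global_batch_idx=global_batch_idx,
                model=model,
                optimizer=optimizer,
                rank=rank,
            )
            # ... then keep only the `topk` most recent checkpoints.
            remove_checkpoints(out_dir=out_dir, topk=topk, rank=rank)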