Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Periodically saving checkpoint after processing given number of batches (#259)

parent 910e6c9306
commit ae564f91e6
decode.py

@@ -58,7 +58,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import beam_search, greedy_search, modified_beam_search
 from train import get_params, get_transducer_model

-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import (
     AttributeDict,
     setup_logger,

@@ -88,6 +92,17 @@ def get_parser():
         "'--epoch'. ",
     )

+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of processed batches while
+        saving that checkpoint.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,

@@ -372,7 +387,12 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    if params.avg == 1:
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
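
Taken together, the decoding change adds a third way to pick model weights: when --avg-last-n is positive, --epoch and --avg are ignored and the n most recent batch-level checkpoints are averaged. Below is a minimal, self-contained sketch of that path; the nn.Linear is a stand-in for the real transducer model and the temporary directory plays the role of exp_dir:

    import tempfile
    from pathlib import Path

    import torch
    import torch.nn as nn

    from icefall.checkpoint import (
        average_checkpoints,
        find_checkpoints,
        save_checkpoint_with_global_batch_idx,
    )

    device = torch.device("cpu")
    model = nn.Linear(4, 2)  # stand-in for get_transducer_model(params)

    with tempfile.TemporaryDirectory() as d:
        exp_dir = Path(d)
        # Pretend training saved a checkpoint every 8000 batches.
        for idx in (8000, 16000, 24000):
            save_checkpoint_with_global_batch_idx(
                out_dir=exp_dir, global_batch_idx=idx, model=model
            )

        # Equivalent to decoding with --avg-last-n 2: find_checkpoints()
        # returns paths sorted by batch index, newest first.
        filenames = find_checkpoints(exp_dir)[:2]
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))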

train.py

@@ -35,7 +35,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

 import k2
 import sentencepiece as spm

@@ -47,6 +47,7 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor

@@ -55,8 +56,9 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam

-from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import (

@@ -113,6 +115,15 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,

@@ -186,6 +197,30 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )

+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
     return parser


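With the defaults above, training writes exp-dir/checkpoint-{batch_idx_train}.pt every 8000 batches and keeps only the newest 20 such files; epoch-xxx.pt files are never pruned. A trivial sketch of the save cadence (the batch_idx_train > 0 guard, visible in a later hunk, keeps batch 0 from triggering a save):

    save_every_n = 8000  # --save-every-n default
    keep_last_k = 20     # --keep-last-k default

    triggers = [i for i in range(1, 50_001) if i % save_every_n == 0]
    print(triggers)  # [8000, 16000, 24000, 32000, 40000, 48000]
    # Once more than keep_last_k checkpoint-xxx.pt files accumulate,
    # remove_checkpoints() deletes the oldest ones after each save.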

@@ -314,15 +349,16 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
-) -> None:
+) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.

-    If params.start_epoch is positive, it will load the checkpoint from
-    `params.start_epoch - 1`. Otherwise, this function does nothing.
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`.

-    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
-    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
     and `best_valid_loss` in `params`.

     Args:

@@ -332,20 +368,22 @@ def load_checkpoint_if_available(
         The training model.
       optimizer:
         The optimizer that we are using.
-      scheduler:
-        The learning rate scheduler we are using.
     Returns:
-      Return None.
+      Return a dict containing previously saved training info.
     """
-    if params.start_epoch <= 0:
-        return
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 0:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None

-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    assert filename.is_file(), f"{filename} does not exist!"
+
     saved_params = load_checkpoint(
         filename,
         model=model,
         optimizer=optimizer,
-        scheduler=scheduler,
     )

     keys = [
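The resolution order is: a positive --start-batch wins, then a positive --start-epoch, otherwise training starts from scratch. A dependency-free sketch of just that precedence (resume_filename is a hypothetical helper, not part of icefall):

    from pathlib import Path
    from typing import Optional

    def resume_filename(
        exp_dir: Path, start_batch: int, start_epoch: int
    ) -> Optional[Path]:
        # Mirrors the branch order in load_checkpoint_if_available().
        if start_batch > 0:
            return exp_dir / f"checkpoint-{start_batch}.pt"
        elif start_epoch > 0:
            return exp_dir / f"epoch-{start_epoch - 1}.pt"
        return None

    print(resume_filename(Path("exp"), 16000, 3))  # exp/checkpoint-16000.pt
    print(resume_filename(Path("exp"), 0, 3))      # exp/epoch-2.pt
    print(resume_filename(Path("exp"), 0, 0))      # None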

@@ -354,10 +392,13 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]

+    params["start_epoch"] = saved_params["cur_epoch"]
+
     return saved_params


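Two effects of this hunk are easy to miss: `cur_batch_idx` (the position within the interrupted epoch) is now carried over from the checkpoint, and `start_epoch` is overwritten with the checkpoint's `cur_epoch`, so resuming from a `checkpoint-xxx.pt` saved mid-epoch re-enters the epoch that was in progress rather than the one given on the command line.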

@@ -365,7 +406,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.

@@ -375,6 +416,10 @@ def save_checkpoint(
         It is returned by :func:`get_params`.
       model:
         The training model.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
     """
     if rank != 0:
         return

@@ -384,7 +429,7 @@ def save_checkpoint(
         model=model,
         params=params,
         optimizer=optimizer,
-        scheduler=scheduler,
+        sampler=sampler,
         rank=rank,
     )


@@ -500,6 +545,7 @@ def train_one_epoch(
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
+    rank: int = 0,
 ) -> None:
     """Train the model for one epoch.


@@ -522,6 +568,9 @@ def train_one_epoch(
         Writer to write log messages to tensorboard.
       world_size:
         Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
     """
     model.train()


@@ -566,7 +615,13 @@ def train_one_epoch(
         else:
             optimizer.step()

+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

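A distilled sketch of the fast-forward behaviour, using a plain dict and list as stand-ins for params and the dataloader; batches consumed before the mid-epoch save are skipped on resumption:

    params = {"cur_batch_idx": 3}  # as restored from checkpoint-xxx.pt
    cur_batch_idx = params.get("cur_batch_idx", 0)

    consumed = []
    for batch_idx, batch in enumerate(["b0", "b1", "b2", "b3", "b4", "b5"]):
        if batch_idx < cur_batch_idx:
            continue  # already processed before the checkpoint was written
        cur_batch_idx = batch_idx
        consumed.append(batch)

    print(consumed)  # ['b3', 'b4', 'b5']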

@@ -591,6 +646,27 @@ def train_one_epoch(

         optimizer.zero_grad()

+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                params=params,
+                optimizer=optimizer,
+                sampler=train_dl.sampler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, "
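Note the ordering inside the trigger: params.cur_batch_idx is set just before the save so the checkpoint records the position within the current epoch, and it is deleted right after, which presumably keeps the end-of-epoch epoch-xxx.pt checkpoints from carrying a stale batch offset.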

@@ -598,8 +674,6 @@ def train_one_epoch(
                 f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )

-        if batch_idx % params.log_interval == 0:
-
             if tb_writer is not None:
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train

@@ -723,7 +797,14 @@ def run(rank, world_size, args):
     logging.info(f"After removing short and long utterances: {num_left}")
     logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")

-    train_dl = librispeech.train_dataloaders(train_cuts)
+    if checkpoints and "sampler" in checkpoints:
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )

     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()

@@ -762,12 +843,14 @@ def run(rank, world_size, args):
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
+            rank=rank,
         )

         save_checkpoint(
             params=params,
             model=model,
             optimizer=optimizer,
+            sampler=train_dl.sampler,
             rank=rank,
         )


asr_datamodule.py

@@ -21,6 +21,7 @@ import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional

 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (

@@ -181,8 +182,18 @@ class LibriSpeechAsrDataModule:
             "with training dataset. ",
         )

-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")

@@ -286,6 +297,10 @@ class LibriSpeechAsrDataModule:
         )
         logging.info("About to create train dataloader")

+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
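This relies on lhotse's CutSampler exposing state_dict() and load_state_dict(): the saved state records how far the sampler had advanced, so, given the same seed and dataset, the restored sampler continues from roughly the point where the checkpoint was written instead of replaying the epoch from the start.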

icefall/checkpoint.py

@@ -15,12 +15,16 @@
 # limitations under the License.


+import glob
 import logging
+import os
+import re
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.nn as nn
+from lhotse.dataset.sampling.base import CutSampler
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer

@@ -34,6 +38,7 @@ def save_checkpoint(
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[_LRScheduler] = None,
     scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save training information to a file.
||||||
@ -69,6 +74,7 @@ def save_checkpoint(
|
|||||||
"optimizer": optimizer.state_dict() if optimizer is not None else None,
|
"optimizer": optimizer.state_dict() if optimizer is not None else None,
|
||||||
"scheduler": scheduler.state_dict() if scheduler is not None else None,
|
"scheduler": scheduler.state_dict() if scheduler is not None else None,
|
||||||
"grad_scaler": scaler.state_dict() if scaler is not None else None,
|
"grad_scaler": scaler.state_dict() if scaler is not None else None,
|
||||||
|
"sampler": sampler.state_dict() if sampler is not None else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
if params:
|
if params:
|
||||||

@@ -85,6 +91,7 @@ def load_checkpoint(
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[_LRScheduler] = None,
     scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
     strict: bool = False,
 ) -> Dict[str, Any]:
     """

@@ -117,6 +124,7 @@ def load_checkpoint(
     load("optimizer", optimizer)
     load("scheduler", scheduler)
     load("grad_scaler", scaler)
+    load("sampler", sampler)

     return checkpoint


@@ -151,3 +159,120 @@ def average_checkpoints(
         avg[k] //= n

     return avg
+
+
+def save_checkpoint_with_global_batch_idx(
+    out_dir: Path,
+    global_batch_idx: int,
+    model: Union[nn.Module, DDP],
+    params: Optional[Dict[str, Any]] = None,
+    optimizer: Optional[Optimizer] = None,
+    scheduler: Optional[_LRScheduler] = None,
+    scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
+    rank: int = 0,
+):
+    """Save training info after processing given number of batches.
+
+    Args:
+      out_dir:
+        The directory to save the checkpoint.
+      global_batch_idx:
+        The number of batches processed so far from the very start of the
+        training. The saved checkpoint will have the following filename:
+
+            f'out_dir / checkpoint-{global_batch_idx}.pt'
+      model:
+        The neural network model whose `state_dict` will be saved in the
+        checkpoint.
+      params:
+        A dict of training configurations to be saved.
+      optimizer:
+        The optimizer used in the training. Its `state_dict` will be saved.
+      scheduler:
+        The learning rate scheduler used in the training. Its `state_dict`
+        will be saved.
+      scaler:
+        The scaler used for mixed precision training. Its `state_dict` will
+        be saved.
+      sampler:
+        The sampler used in the training dataset.
+      rank:
+        The rank ID used in DDP training of the current node. Set it to 0
+        if DDP is not used.
+    """
+    out_dir = Path(out_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
+    save_checkpoint(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        scaler=scaler,
+        sampler=sampler,
+        rank=rank,
+    )
+
+
+def find_checkpoints(out_dir: Path) -> List[str]:
+    """Find all available checkpoints in a directory.
+
+    The checkpoint filenames have the form: `checkpoint-xxx.pt`
+    where xxx is a numerical value.
+
+    Args:
+      out_dir:
+        The directory where to search for checkpoints.
+    Returns:
+      Return a list of checkpoint filenames, sorted in descending
+      order by the numerical value in the filename.
+    """
+    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
+    pattern = re.compile(r"checkpoint-([0-9]+).pt")
+    idx_checkpoints = [
+        (int(pattern.search(c).group(1)), c) for c in checkpoints
+    ]
+
+    idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
+    ans = [ic[1] for ic in idx_checkpoints]
+    return ans
+
+
+def remove_checkpoints(
+    out_dir: Path,
+    topk: int,
+    rank: int = 0,
+):
+    """Remove checkpoints from the given directory.
+
+    We assume that checkpoint filename has the form `checkpoint-xxx.pt`
+    where xxx is a number, representing the number of processed batches
+    when saving that checkpoint. We sort checkpoints by filename and keep
+    only the `topk` checkpoints with the highest `xxx`.
+
+    Args:
+      out_dir:
+        The directory containing checkpoints to be removed.
+      topk:
+        Number of checkpoints to keep.
+      rank:
+        If using DDP for training, it is the rank of the current node.
+        Use 0 if no DDP is used for training.
+    """
+    assert topk >= 1, topk
+    if rank != 0:
+        return
+    checkpoints = find_checkpoints(out_dir)
+
+    if len(checkpoints) == 0:
+        logging.warn(f"No checkpoints found in {out_dir}")
+        return
+
+    if len(checkpoints) <= topk:
+        return
+
+    to_remove = checkpoints[topk:]
+    for c in to_remove:
+        os.remove(c)
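
Since find_checkpoints() and remove_checkpoints() only look at filenames, they are easy to exercise in isolation; a small self-contained check using empty dummy files in a temporary directory:

    import tempfile
    from pathlib import Path

    from icefall.checkpoint import find_checkpoints, remove_checkpoints

    with tempfile.TemporaryDirectory() as d:
        out_dir = Path(d)
        # find_checkpoints() only inspects filenames, so empty files suffice.
        for idx in (8000, 16000, 24000, 32000):
            (out_dir / f"checkpoint-{idx}.pt").touch()

        print([Path(c).name for c in find_checkpoints(out_dir)])
        # ['checkpoint-32000.pt', 'checkpoint-24000.pt',
        #  'checkpoint-16000.pt', 'checkpoint-8000.pt']

        remove_checkpoints(out_dir=out_dir, topk=2, rank=0)
        print([Path(c).name for c in find_checkpoints(out_dir)])
        # ['checkpoint-32000.pt', 'checkpoint-24000.pt']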