Save checkpoints periodically.

Fangjun Kuang 2022-03-20 22:40:14 +08:00
parent 1a9e8a5718
commit a1189615f3
4 changed files with 177 additions and 33 deletions

View File

@ -58,7 +58,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
@ -88,6 +92,17 @@ def get_parser():
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt,
where xxx is the number of batches processed when that
checkpoint was saved.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -372,7 +387,12 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg == 1:
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
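The new --avg-last-n branch combines find_checkpoints with average_checkpoints. As a rough illustration of what that averaging amounts to, here is a minimal sketch, assuming each checkpoint-xxx.pt stores the model weights under a "model" key; the helper name and the element-wise loop are illustrative, not icefall's actual average_checkpoints:

import torch

def average_last_n(filenames, device=torch.device("cpu")):
    # Element-wise average of the "model" state dicts stored in `filenames`.
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location=device)["model"]
        for k in avg:
            if avg[k].is_floating_point():
                avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
    return avg

The resulting dict can be passed straight to model.load_state_dict(), which is what the branch above does with the filenames returned by find_checkpoints(params.exp_dir)[: params.avg_last_n].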

View File

@ -35,7 +35,7 @@ import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import k2
import sentencepiece as spm
@ -47,6 +47,7 @@ from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
@ -55,8 +56,9 @@ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
@ -113,6 +115,15 @@ def get_parser():
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -186,6 +197,30 @@ def get_parser():
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--save-every-n",
type=int,
default=8000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
""",
)
parser.add_argument(
"--keep-last-k",
type=int,
default=10,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
It does not affect checkpoints with name `epoch-xxx.pt`.
""",
)
return parser
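To make the interaction between --save-every-n and --keep-last-k concrete, here is a small self-contained calculation; all numbers below are made up for illustration:

save_every_n = 8000
keep_last_k = 10
batch_idx_train = 96000  # pretend the run has processed this many batches

# Checkpoints written so far: checkpoint-8000.pt, checkpoint-16000.pt, ...
saved = [
    f"checkpoint-{i}.pt"
    for i in range(save_every_n, batch_idx_train + 1, save_every_n)
]
# remove_checkpoints() keeps only the keep_last_k most recent ones.
kept = sorted(
    saved, key=lambda f: int(f.split("-")[1].split(".")[0]), reverse=True
)[:keep_last_k]
print(kept[-1], "...", kept[0])
# checkpoint-24000.pt ... checkpoint-96000.pt

With these settings, checkpoint-8000.pt and checkpoint-16000.pt have already been pruned, while the epoch-xxx.pt files written at the end of each epoch are untouched.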
@ -314,15 +349,16 @@ def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
Apart from loading the state dict for `model` and `optimizer`, it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
@ -332,20 +368,22 @@ def load_checkpoint_if_available(
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
Return a dict containing previously saved training info.
"""
if params.start_epoch <= 0:
return
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 0:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
@ -354,10 +392,13 @@ def load_checkpoint_if_available(
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
"cur_batch_idx",
]
for k in keys:
params[k] = saved_params[k]
params["start_epoch"] = saved_params["cur_epoch"]
return saved_params
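The resume precedence introduced here (a positive --start-batch wins over --start-epoch) can be summarised with a tiny sketch; the helper below only mirrors the branch order above and is not part of the recipe:

from pathlib import Path

def resume_filename(exp_dir: Path, start_batch: int, start_epoch: int):
    # checkpoint-{start_batch}.pt first, then epoch-{start_epoch - 1}.pt,
    # otherwise train from scratch.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    elif start_epoch > 0:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None

print(resume_filename(Path("exp"), start_batch=16000, start_epoch=3))
# exp/checkpoint-16000.pt -- start_epoch is ignored once start_batch > 0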
@ -365,7 +406,7 @@ def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -375,6 +416,10 @@ def save_checkpoint(
It is returned by :func:`get_params`.
model:
The training model.
optimizer:
The optimizer used in the training.
sampler:
The sampler for the training dataset.
"""
if rank != 0:
return
@ -384,7 +429,7 @@ def save_checkpoint(
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=sampler,
rank=rank,
)
@ -500,6 +545,7 @@ def train_one_epoch(
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
@ -522,6 +568,9 @@ def train_one_epoch(
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
@ -566,7 +615,13 @@ def train_one_epoch(
else:
optimizer.step()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
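The cur_batch_idx bookkeeping above lets a run resumed from a mid-epoch checkpoint skip the batches it has already consumed before continuing. A toy illustration with made-up values:

cur_batch_idx = 3  # restored via params.get("cur_batch_idx", 0) after resuming
for batch_idx, batch in enumerate(["b0", "b1", "b2", "b3", "b4"]):
    if batch_idx < cur_batch_idx:
        continue  # already seen before the checkpoint was written
    cur_batch_idx = batch_idx
    print("training on", batch)
# training on b3
# training on b4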
@ -591,6 +646,27 @@ def train_one_epoch(
optimizer.zero_grad()
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
params=params,
optimizer=optimizer,
sampler=train_dl.sampler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
@ -598,8 +674,6 @@ def train_one_epoch(
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@ -723,7 +797,14 @@ def run(rank, world_size, args):
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = librispeech.train_dataloaders(train_cuts)
if checkpoints and "sampler" in checkpoints:
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
train_dl = librispeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
@ -762,12 +843,14 @@ def run(rank, world_size, args):
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
sampler=train_dl.sampler,
rank=rank,
)

View File

@ -21,6 +21,7 @@ import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
@ -181,8 +182,18 @@ class LibriSpeechAsrDataModule:
"with training dataset. ",
)
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
@ -286,6 +297,10 @@ class LibriSpeechAsrDataModule:
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
train_dl = DataLoader(
train,
sampler=train_sampler,
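The sampler_state_dict plumbing above relies on lhotse's CutSampler exposing state_dict() and load_state_dict(). A minimal sketch of that round trip, assuming a SingleCutSampler and an illustrative manifest path (the recipe may use a different sampler class and options):

from lhotse import CutSet
from lhotse.dataset import SingleCutSampler

cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")  # path is illustrative

sampler = SingleCutSampler(cuts, max_duration=200, shuffle=True)
next(iter(sampler))           # pretend part of the epoch has been consumed
state = sampler.state_dict()  # this is what the checkpoint stores under "sampler"

resumed = SingleCutSampler(cuts, max_duration=200, shuffle=True)
resumed.load_state_dict(state)  # continues from where `sampler` stopped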

View File

@ -24,6 +24,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@ -37,6 +38,7 @@ def save_checkpoint(
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
"""Save training information to a file.
@ -72,6 +74,7 @@ def save_checkpoint(
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"scheduler": scheduler.state_dict() if scheduler is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
"sampler": sampler.state_dict() if sampler is not None else None,
}
if params:
@ -88,6 +91,7 @@ def load_checkpoint(
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
strict: bool = False,
) -> Dict[str, Any]:
"""
@ -120,6 +124,7 @@ def load_checkpoint(
load("optimizer", optimizer)
load("scheduler", scheduler)
load("grad_scaler", scaler)
load("sampler", sampler)
return checkpoint
@ -164,6 +169,7 @@ def save_checkpoint_with_global_batch_idx(
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
):
"""Save training info after processing given number of batches.
@ -189,6 +195,8 @@ def save_checkpoint_with_global_batch_idx(
scaler:
The scaler used for mixed precision training. Its `state_dict` will
be saved.
sampler:
The sampler for the training dataset.
rank:
The rank ID used in DDP training of the current node. Set it to 0
if DDP is not used.
@ -203,10 +211,35 @@ def save_checkpoint_with_global_batch_idx(
optimizer=optimizer,
scheduler=scheduler,
scaler=scaler,
sampler=sampler,
rank=rank,
)
def find_checkpoints(out_dir: Path) -> List[str]:
"""Find all available checkpoints in a directory.
The checkpoint filenames have the form: `checkpoint-xxx.pt`
where xxx is a numerical value.
Args:
out_dir:
The directory where to search for checkpoints.
Returns:
Return a list of checkpoint filenames, sorted in descending
order by the numerical value in the filename.
"""
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
pattern = re.compile(r"checkpoint-([0-9]+).pt")
idx_checkpoints = [
(int(pattern.search(c).group(1)), c) for c in checkpoints
]
idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
ans = [ic[1] for ic in idx_checkpoints]
return ans
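A quick demonstration of the descending sort that find_checkpoints performs; the filenames below are made up:

import re

names = ["exp/checkpoint-8000.pt", "exp/checkpoint-24000.pt", "exp/checkpoint-16000.pt"]
pattern = re.compile(r"checkpoint-([0-9]+).pt")
ordered = sorted(names, key=lambda c: int(pattern.search(c).group(1)), reverse=True)
print(ordered)
# ['exp/checkpoint-24000.pt', 'exp/checkpoint-16000.pt', 'exp/checkpoint-8000.pt']

Because the newest checkpoints come first, find_checkpoints(out_dir)[:n] picks the n most recent ones for --avg-last-n, and remove_checkpoints(out_dir, topk=k) simply deletes everything after the first k entries.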
def remove_checkpoints(
out_dir: Path,
topk: int,
@ -231,7 +264,7 @@ def remove_checkpoints(
assert topk >= 1, topk
if rank != 0:
return
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
checkpoints = find_checkpoints(out_dir)
if len(checkpoints) == 0:
logging.warn(f"No checkpoints found in {out_dir}")
@ -240,13 +273,6 @@ def remove_checkpoints(
if len(checkpoints) <= topk:
return
pattern = re.compile(r"checkpoint-([0-9]+).pt")
idx_checkpoints = [
(int(pattern.search(c).group(1)), c) for c in checkpoints
]
idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
to_remove = idx_checkpoints[topk:]
to_remove = [ic[1] for ic in to_remove]
to_remove = checkpoints[topk:]
for c in to_remove:
os.remove(c)