Periodically saving checkpoint after processing given number of batches (#259)

* Periodically saving checkpoint after processing given number of batches.
Fangjun Kuang authored 2022-03-20 23:51:33 +08:00, committed by GitHub
parent 910e6c9306
commit ae564f91e6
4 changed files with 267 additions and 24 deletions

decode.py

@@ -58,7 +58,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import beam_search, greedy_search, modified_beam_search
 from train import get_params, get_transducer_model

-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -88,6 +92,17 @@ def get_parser():
         "'--epoch'. ",
     )

+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of batches processed when that
+        checkpoint was saved.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -372,7 +387,12 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    if params.avg == 1:
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
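
In effect, a positive --avg-last-n makes decoding ignore --epoch/--avg and average the n most recent batch-level checkpoints instead. A minimal sketch of that branch, with an illustrative path and a stand-in nn.Linear in place of the real transducer model:

import torch
import torch.nn as nn

from icefall.checkpoint import average_checkpoints, find_checkpoints

exp_dir = "transducer_stateless/exp"  # illustrative path
device = torch.device("cpu")
avg_last_n = 3  # what --avg-last-n 3 would request

# Stand-in for get_transducer_model(params); the only requirement is
# that its state_dict matches what the checkpoints in exp_dir contain.
model = nn.Linear(80, 500)

# find_checkpoints() sorts checkpoint-xxx.pt newest first, so slicing
# selects the avg_last_n most recent checkpoints.
filenames = find_checkpoints(exp_dir)[:avg_last_n]
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()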

train.py

@@ -35,7 +35,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

 import k2
 import sentencepiece as spm
@@ -47,6 +47,7 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
@@ -55,8 +56,9 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam

-from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import (
@@ -113,6 +115,15 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and it loads
+        the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -186,6 +197,30 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )

+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save a checkpoint after processing this number of batches,
+        periodically. We save a checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoints to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
     return parser
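
With these defaults a run writes checkpoint-8000.pt, checkpoint-16000.pt, and so on, and never keeps more than 20 of them on disk. A small self-contained check of that arithmetic, with illustrative numbers:

save_every_n, keep_last_k = 8000, 20  # the defaults above

# Global batch indices at which checkpoint-xxx.pt would be written
# during the first 200k training batches:
saved = [n for n in range(1, 200_001) if n % save_every_n == 0]
assert saved[:3] == [8000, 16000, 24000]

# remove_checkpoints() keeps only the newest keep_last_k of them:
kept = sorted(saved, reverse=True)[:keep_last_k]
assert kept[0] == 200_000 and kept[-1] == 48_000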
@@ -314,15 +349,16 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
-) -> None:
+) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.

-    If params.start_epoch is positive, it will load the checkpoint from
-    `params.start_epoch - 1`. Otherwise, this function does nothing.
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`.

-    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
-    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
     and `best_valid_loss` in `params`.

     Args:
@@ -332,20 +368,22 @@ def load_checkpoint_if_available(
         The training model.
       optimizer:
         The optimizer that we are using.
-      scheduler:
-        The learning rate scheduler we are using.

     Returns:
-      Return None.
+      Return a dict containing previously saved training info.
     """
-    if params.start_epoch <= 0:
-        return
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 0:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None

-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    assert filename.is_file(), f"{filename} does not exist!"

     saved_params = load_checkpoint(
         filename,
         model=model,
         optimizer=optimizer,
-        scheduler=scheduler,
     )

     keys = [
@@ -354,10 +392,13 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]

+    params["start_epoch"] = saved_params["cur_epoch"]
+
     return saved_params
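
The resume precedence is therefore: --start-batch wins over --start-epoch, and if neither is positive nothing is loaded. A condensed restatement of the selection logic (pick_resume_file is a hypothetical helper, not part of this commit):

from pathlib import Path
from typing import Optional

def pick_resume_file(
    exp_dir: Path, start_batch: int, start_epoch: int
) -> Optional[Path]:
    # Same precedence as load_checkpoint_if_available() above.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    elif start_epoch > 0:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None

assert pick_resume_file(Path("exp"), 16000, 5) == Path("exp/checkpoint-16000.pt")
assert pick_resume_file(Path("exp"), 0, 5) == Path("exp/epoch-4.pt")
assert pick_resume_file(Path("exp"), 0, 0) is None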
@@ -365,7 +406,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -375,6 +416,10 @@ def save_checkpoint(
         It is returned by :func:`get_params`.
       model:
         The training model.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
     """
     if rank != 0:
         return
@@ -384,7 +429,7 @@ def save_checkpoint(
         model=model,
         params=params,
         optimizer=optimizer,
-        scheduler=scheduler,
+        sampler=sampler,
         rank=rank,
     )
@@ -500,6 +545,7 @@ def train_one_epoch(
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
+    rank: int = 0,
 ) -> None:
     """Train the model for one epoch.
@@ -522,6 +568,9 @@ def train_one_epoch(
         Writer to write log messages to tensorboard.
       world_size:
         Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
     """
     model.train()
@@ -566,7 +615,13 @@ def train_one_epoch(
         else:
             optimizer.step()

+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -591,6 +646,27 @@ def train_one_epoch(
         optimizer.zero_grad()

+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                params=params,
+                optimizer=optimizer,
+                sampler=train_dl.sampler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, "
@@ -598,8 +674,6 @@ def train_one_epoch(
                 f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )

-        if batch_idx % params.log_interval == 0:
-
             if tb_writer is not None:
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
@@ -723,7 +797,14 @@ def run(rank, world_size, args):
     logging.info(f"After removing short and long utterances: {num_left}")
     logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")

-    train_dl = librispeech.train_dataloaders(train_cuts)
+    if checkpoints and "sampler" in checkpoints:
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )

     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
@@ -762,12 +843,14 @@ def run(rank, world_size, args):
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
+            rank=rank,
         )

         save_checkpoint(
            params=params,
            model=model,
            optimizer=optimizer,
+           sampler=train_dl.sampler,
            rank=rank,
        )
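
The resume plumbing in run() then reduces to a dictionary hand-off: the dict returned by load_checkpoint_if_available() may carry a "sampler" entry, which is forwarded to the data module. Stripped to just that data flow (stand-in dict, since the real objects need a full training setup):

# Stand-in for the dict returned by load_checkpoint_if_available();
# a run resumed with --start-batch gets a "sampler" entry from the
# checkpoint, a fresh run gets None instead of a dict.
checkpoints = {"sampler": {"illustrative": "state"}}

if checkpoints and "sampler" in checkpoints:
    sampler_state_dict = checkpoints["sampler"]
else:
    sampler_state_dict = None

# Forwarded into the data module, which replays it into the new sampler:
#   train_dl = librispeech.train_dataloaders(
#       train_cuts, sampler_state_dict=sampler_state_dict
#   )
assert sampler_state_dict == {"illustrative": "state"}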

asr_datamodule.py

@@ -21,6 +21,7 @@ import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional

 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
@@ -181,8 +182,18 @@ class LibriSpeechAsrDataModule:
             "with training dataset. ",
         )

-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
@@ -286,6 +297,10 @@ class LibriSpeechAsrDataModule:
         )

         logging.info("About to create train dataloader")
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
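
Restoring the sampler state before the DataLoader is built is what lets a resumed run continue mid-epoch instead of re-drawing cuts from scratch. The pattern, shown with a stand-in sampler so it runs without any LibriSpeech data (lhotse cut samplers expose the same state_dict()/load_state_dict() pair):

class ToySampler:
    # Minimal stand-in for a lhotse CutSampler: the only contract the
    # data module relies on here is the state_dict round trip.
    def __init__(self):
        self.num_batches_served = 0

    def state_dict(self):
        return {"num_batches_served": self.num_batches_served}

    def load_state_dict(self, state):
        self.num_batches_served = state["num_batches_served"]

old_sampler = ToySampler()
old_sampler.num_batches_served = 8000  # progress captured at save time

new_sampler = ToySampler()  # built fresh by train_dataloaders()
new_sampler.load_state_dict(old_sampler.state_dict())
assert new_sampler.num_batches_served == 8000  # resumes, not restarts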

icefall/checkpoint.py

@@ -15,12 +15,16 @@
 # limitations under the License.

+import glob
 import logging
+import os
+import re
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.nn as nn
+from lhotse.dataset.sampling.base import CutSampler
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -34,6 +38,7 @@ def save_checkpoint(
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[_LRScheduler] = None,
     scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save training information to a file.
@@ -69,6 +74,7 @@ def save_checkpoint(
         "optimizer": optimizer.state_dict() if optimizer is not None else None,
         "scheduler": scheduler.state_dict() if scheduler is not None else None,
         "grad_scaler": scaler.state_dict() if scaler is not None else None,
+        "sampler": sampler.state_dict() if sampler is not None else None,
     }

     if params:
@@ -85,6 +91,7 @@ def load_checkpoint(
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[_LRScheduler] = None,
     scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
     strict: bool = False,
 ) -> Dict[str, Any]:
     """
@@ -117,6 +124,7 @@ def load_checkpoint(
     load("optimizer", optimizer)
     load("scheduler", scheduler)
     load("grad_scaler", scaler)
+    load("sampler", sampler)

     return checkpoint
@@ -151,3 +159,120 @@ def average_checkpoints(
             avg[k] //= n

     return avg
+
+
+def save_checkpoint_with_global_batch_idx(
+    out_dir: Path,
+    global_batch_idx: int,
+    model: Union[nn.Module, DDP],
+    params: Optional[Dict[str, Any]] = None,
+    optimizer: Optional[Optimizer] = None,
+    scheduler: Optional[_LRScheduler] = None,
+    scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
+    rank: int = 0,
+):
+    """Save training info after processing given number of batches.
+
+    Args:
+      out_dir:
+        The directory to save the checkpoint.
+      global_batch_idx:
+        The number of batches processed so far from the very start of the
+        training. The saved checkpoint will have the following filename:
+        f'out_dir / checkpoint-{global_batch_idx}.pt'
+      model:
+        The neural network model whose `state_dict` will be saved in the
+        checkpoint.
+      params:
+        A dict of training configurations to be saved.
+      optimizer:
+        The optimizer used in the training. Its `state_dict` will be saved.
+      scheduler:
+        The learning rate scheduler used in the training. Its `state_dict`
+        will be saved.
+      scaler:
+        The scaler used for mixed precision training. Its `state_dict` will
+        be saved.
+      sampler:
+        The sampler used in the training dataset.
+      rank:
+        The rank ID used in DDP training of the current node. Set it to 0
+        if DDP is not used.
+    """
+    out_dir = Path(out_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
+    save_checkpoint(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        scaler=scaler,
+        sampler=sampler,
+        rank=rank,
+    )
+
+
+def find_checkpoints(out_dir: Path) -> List[str]:
+    """Find all available checkpoints in a directory.
+
+    The checkpoint filenames have the form `checkpoint-xxx.pt`,
+    where xxx is a numerical value.
+
+    Args:
+      out_dir:
+        The directory where to search for checkpoints.
+    Returns:
+      Return a list of checkpoint filenames, sorted in descending
+      order by the numerical value in the filename.
+    """
+    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
+    pattern = re.compile(r"checkpoint-([0-9]+).pt")
+    idx_checkpoints = [
+        (int(pattern.search(c).group(1)), c) for c in checkpoints
+    ]
+    idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
+    ans = [ic[1] for ic in idx_checkpoints]
+    return ans
+
+
+def remove_checkpoints(
+    out_dir: Path,
+    topk: int,
+    rank: int = 0,
+):
+    """Remove checkpoints from the given directory.
+
+    We assume that checkpoint filenames have the form `checkpoint-xxx.pt`,
+    where xxx is a number representing the number of processed batches
+    when that checkpoint was saved. We sort checkpoints by filename and
+    keep only the `topk` checkpoints with the highest `xxx`.
+
+    Args:
+      out_dir:
+        The directory containing checkpoints to be removed.
+      topk:
+        Number of checkpoints to keep.
+      rank:
+        If using DDP for training, it is the rank of the current node.
+        Use 0 if no DDP is used for training.
+    """
+    assert topk >= 1, topk
+    if rank != 0:
+        return
+
+    checkpoints = find_checkpoints(out_dir)
+
+    if len(checkpoints) == 0:
+        logging.warn(f"No checkpoints found in {out_dir}")
+        return
+
+    if len(checkpoints) <= topk:
+        return
+
+    to_remove = checkpoints[topk:]
+    for c in to_remove:
+        os.remove(c)
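
The three new helpers can be exercised in isolation. A small smoke test in a temporary directory, with a tiny stand-in model and illustrative batch indices:

import tempfile
from pathlib import Path

import torch.nn as nn

from icefall.checkpoint import (
    find_checkpoints,
    remove_checkpoints,
    save_checkpoint_with_global_batch_idx,
)

model = nn.Linear(2, 2)  # stand-in for the transducer model

with tempfile.TemporaryDirectory() as d:
    out_dir = Path(d)
    for global_batch_idx in (8000, 16000, 24000):
        save_checkpoint_with_global_batch_idx(
            out_dir=out_dir,
            global_batch_idx=global_batch_idx,
            model=model,
        )

    # Newest first: checkpoint-24000.pt, checkpoint-16000.pt, checkpoint-8000.pt
    print(find_checkpoints(out_dir))

    # Keep only the 2 most recent; checkpoint-8000.pt is deleted.
    remove_checkpoints(out_dir=out_dir, topk=2)
    print(find_checkpoints(out_dir))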