From ae564f91e6981321a715d3ce1ddf5dec5cc21296 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 20 Mar 2022 23:51:33 +0800
Subject: [PATCH] Periodically saving checkpoint after processing given number
 of batches (#259)

* Periodically saving checkpoint after processing given number of batches.
---
 .../ASR/pruned_transducer_stateless/decode.py |  24 +++-
 .../ASR/pruned_transducer_stateless/train.py  | 123 ++++++++++++++---
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       |  19 ++-
 icefall/checkpoint.py                         | 125 ++++++++++++++++++
 4 files changed, 267 insertions(+), 24 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 86ec6172f..fedf663b8 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -58,7 +58,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import beam_search, greedy_search, modified_beam_search
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -88,6 +92,17 @@ def get_parser():
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of processed batches while
+        saving that checkpoint.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -372,7 +387,12 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg == 1:
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index f0ea2ccaa..e71f0d1c6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -35,7 +35,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import k2
 import sentencepiece as spm
@@ -47,6 +47,7 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
@@ -55,8 +56,9 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
 
-from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import (
@@ -113,6 +115,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -186,6 +197,30 @@
         help="The seed for random generators intended for reproducibility",
     )
 
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
     return parser
@@ -314,15 +349,16 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
-) -> None:
+) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.
 
-    If params.start_epoch is positive, it will load the checkpoint from
-    `params.start_epoch - 1`. Otherwise, this function does nothing.
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`.
 
-    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
-    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    Apart from loading state dict for `model` and `optimizer`, it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
     and `best_valid_loss` in `params`.
 
     Args:
@@ -332,20 +368,22 @@ def load_checkpoint_if_available(
         The training model.
       optimizer:
         The optimizer that we are using.
-      scheduler:
-        The learning rate scheduler we are using.
     Returns:
-      Return None.
+      Return a dict containing previously saved training info.
     """
-    if params.start_epoch <= 0:
-        return
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 0:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
 
-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
         filename,
         model=model,
        optimizer=optimizer,
-        scheduler=scheduler,
     )
 
     keys = [
@@ -354,10 +392,13 @@
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]
 
+    params["start_epoch"] = saved_params["cur_epoch"]
+
     return saved_params
 
 
@@ -365,7 +406,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -375,6 +416,10 @@
         It is returned by :func:`get_params`.
       model:
         The training model.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
     """
     if rank != 0:
         return
@@ -384,7 +429,7 @@ def save_checkpoint(
         model=model,
         params=params,
         optimizer=optimizer,
-        scheduler=scheduler,
+        sampler=sampler,
         rank=rank,
     )
 
@@ -500,6 +545,7 @@
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
+    rank: int = 0,
 ) -> None:
     """Train the model for one epoch.
 
@@ -522,6 +568,9 @@
         Writer to write log messages to tensorboard.
       world_size:
         Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
     """
     model.train()
 
@@ -566,7 +615,13 @@
         else:
             optimizer.step()
 
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
@@ -591,6 +646,27 @@
 
         optimizer.zero_grad()
 
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                params=params,
+                optimizer=optimizer,
+                sampler=train_dl.sampler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, "
@@ -598,8 +674,6 @@
                 f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )
 
-        if batch_idx % params.log_interval == 0:
-
             if tb_writer is not None:
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
@@ -723,7 +797,14 @@
     logging.info(f"After removing short and long utterances: {num_left}")
     logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
 
-    train_dl = librispeech.train_dataloaders(train_cuts)
+    if checkpoints and "sampler" in checkpoints:
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
 
     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
@@ -762,12 +843,14 @@
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
+            rank=rank,
         )
 
         save_checkpoint(
             params=params,
             model=model,
             optimizer=optimizer,
+            sampler=train_dl.sampler,
             rank=rank,
         )
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 51e10fb2f..a460c8eb8 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -21,6 +21,7 @@ import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional
 
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
@@ -181,8 +182,18 @@ class LibriSpeechAsrDataModule:
             "with training dataset. ",
         )
 
-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
-
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
@@ -286,6 +297,10 @@
         )
 
         logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index dbe3c1315..251456c95 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -15,12 +15,16 @@
 # limitations under the License.
 
 
+import glob
 import logging
+import os
+import re
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.nn as nn
+from lhotse.dataset.sampling.base import CutSampler
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -34,6 +38,7 @@ def save_checkpoint(
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[_LRScheduler] = None,
     scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save training information to a file.
@@ -69,6 +74,7 @@
         "optimizer": optimizer.state_dict() if optimizer is not None else None,
         "scheduler": scheduler.state_dict() if scheduler is not None else None,
         "grad_scaler": scaler.state_dict() if scaler is not None else None,
+        "sampler": sampler.state_dict() if sampler is not None else None,
     }
 
     if params:
@@ -85,6 +91,7 @@
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[_LRScheduler] = None,
     scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
     strict: bool = False,
 ) -> Dict[str, Any]:
     """
@@ -117,6 +124,7 @@
     load("optimizer", optimizer)
     load("scheduler", scheduler)
     load("grad_scaler", scaler)
+    load("sampler", sampler)
 
     return checkpoint
 
@@ -151,3 +159,120 @@
             avg[k] //= n
 
     return avg
+
+
+def save_checkpoint_with_global_batch_idx(
+    out_dir: Path,
+    global_batch_idx: int,
+    model: Union[nn.Module, DDP],
+    params: Optional[Dict[str, Any]] = None,
+    optimizer: Optional[Optimizer] = None,
+    scheduler: Optional[_LRScheduler] = None,
+    scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
+    rank: int = 0,
+):
+    """Save training info after processing given number of batches.
+
+    Args:
+      out_dir:
+        The directory to save the checkpoint.
+      global_batch_idx:
+        The number of batches processed so far from the very start of the
+        training. The saved checkpoint will have the following filename:
+
+            f'out_dir / checkpoint-{global_batch_idx}.pt'
+      model:
+        The neural network model whose `state_dict` will be saved in the
+        checkpoint.
+      params:
+        A dict of training configurations to be saved.
+      optimizer:
+        The optimizer used in the training. Its `state_dict` will be saved.
+      scheduler:
+        The learning rate scheduler used in the training. Its `state_dict` will
+        be saved.
+      scaler:
+        The scaler used for mixed precision training. Its `state_dict` will
+        be saved.
+      sampler:
+        The sampler used in the training dataset.
+      rank:
+        The rank ID used in DDP training of the current node. Set it to 0
+        if DDP is not used.
+ """ + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + filename = out_dir / f"checkpoint-{global_batch_idx}.pt" + save_checkpoint( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + sampler=sampler, + rank=rank, + ) + + +def find_checkpoints(out_dir: Path) -> List[str]: + """Find all available checkpoints in a directory. + + The checkpoint filenames have the form: `checkpoint-xxx.pt` + where xxx is a numerical value. + + Args: + out_dir: + The directory where to search for checkpoints. + Returns: + Return a list of checkpoint filenames, sorted in descending + order by the numerical value in the filename. + """ + checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) + pattern = re.compile(r"checkpoint-([0-9]+).pt") + idx_checkpoints = [ + (int(pattern.search(c).group(1)), c) for c in checkpoints + ] + + idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0]) + ans = [ic[1] for ic in idx_checkpoints] + return ans + + +def remove_checkpoints( + out_dir: Path, + topk: int, + rank: int = 0, +): + """Remove checkpoints from the given directory. + + We assume that checkpoint filename has the form `checkpoint-xxx.pt` + where xxx is a number, representing the number of processed batches + when saving that checkpoint. We sort checkpoints by filename and keep + only the `topk` checkpoints with the highest `xxx`. + + Args: + out_dir: + The directory containing checkpoints to be removed. + topk: + Number of checkpoints to keep. + rank: + If using DDP for training, it is the rank of the current node. + Use 0 if no DDP is used for training. + """ + assert topk >= 1, topk + if rank != 0: + return + checkpoints = find_checkpoints(out_dir) + + if len(checkpoints) == 0: + logging.warn(f"No checkpoints found in {out_dir}") + return + + if len(checkpoints) <= topk: + return + + to_remove = checkpoints[topk:] + for c in to_remove: + os.remove(c)