Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Periodically saving checkpoint after processing given number of batches (#259)
* Periodically saving checkpoint after processing given number of batches.
Parent commit: 910e6c9306
This commit: ae564f91e6
@@ -58,7 +58,11 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import beam_search, greedy_search, modified_beam_search
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -88,6 +92,17 @@ def get_parser():
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of processed batches while
+        saving that checkpoint.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -372,7 +387,12 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg == 1:
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
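Note: the three hunks above appear to come from the recipe's decode.py. They add an --avg-last-n option that averages the N most recent batch-level checkpoints (checkpoint-xxx.pt, produced by the training changes below) instead of epoch checkpoints. A minimal, self-contained sketch of just the selection step, using made-up filenames rather than the recipe's code:

# Illustration only: mimic how the n most recent batch checkpoints are chosen.
checkpoints = [
    "exp/checkpoint-32000.pt",
    "exp/checkpoint-8000.pt",
    "exp/checkpoint-24000.pt",
    "exp/checkpoint-16000.pt",
]

def batch_idx(name: str) -> int:
    # "exp/checkpoint-24000.pt" -> 24000
    return int(name.rsplit("-", 1)[1].split(".")[0])

newest_first = sorted(checkpoints, key=batch_idx, reverse=True)
avg_last_n = 2
print(newest_first[:avg_last_n])
# ['exp/checkpoint-32000.pt', 'exp/checkpoint-24000.pt']

The real code can simply slice with [: params.avg_last_n] because find_checkpoints() (added in icefall/checkpoint.py further down) already returns filenames newest-first.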
@@ -35,7 +35,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import k2
 import sentencepiece as spm
@@ -47,6 +47,7 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
@@ -55,8 +56,9 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
 
-from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import (
@@ -113,6 +115,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -186,6 +197,30 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )
 
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
     return parser
 
 
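Note: with the two options above, a checkpoint is written whenever the global batch counter is a multiple of --save-every-n, and only the newest --keep-last-k of those batch-level files are retained. A self-contained sketch with toy numbers (the real defaults are 8000 and 20):

# Illustration only: which batch indices trigger a save, and which files
# remain after keeping only the last `keep_last_k` batch-level checkpoints.
save_every_n = 100   # recipe default is 8000
keep_last_k = 3      # recipe default is 20
total_batches = 650

saved = [
    b for b in range(1, total_batches + 1)
    if b > 0 and b % save_every_n == 0
]
print(saved)          # [100, 200, 300, 400, 500, 600]

# remove_checkpoints() keeps the checkpoints with the highest batch index.
kept = sorted(saved, reverse=True)[:keep_last_k]
print(sorted(kept))   # [400, 500, 600]
# epoch-xxx.pt files written at the end of each epoch are never pruned.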
@@ -314,15 +349,16 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
-) -> None:
+) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.
 
-    If params.start_epoch is positive, it will load the checkpoint from
-    `params.start_epoch - 1`. Otherwise, this function does nothing.
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`.
 
-    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
-    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
     and `best_valid_loss` in `params`.
 
     Args:
@@ -332,20 +368,22 @@ def load_checkpoint_if_available(
         The training model.
       optimizer:
         The optimizer that we are using.
       scheduler:
         The learning rate scheduler we are using.
     Returns:
-      Return None.
+      Return a dict containing previously saved training info.
     """
-    if params.start_epoch <= 0:
-        return
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 0:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
 
-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
         filename,
         model=model,
         optimizer=optimizer,
         scheduler=scheduler,
     )
 
     keys = [
@@ -354,10 +392,13 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]
 
+    params["start_epoch"] = saved_params["cur_epoch"]
+
     return saved_params
 
 
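Note: the rewritten docstring and body above give batch-level resumption (--start-batch) priority over epoch-level resumption (--start-epoch), and the epoch counter is then restored from the checkpoint itself. A standalone sketch that mirrors only the filename-resolution part (the real function also restores model/optimizer/scheduler state and the bookkeeping keys listed above):

from pathlib import Path
from typing import Optional

# Illustration only: mirrors the resume-precedence logic shown above.
def resolve_resume_file(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    elif start_epoch > 0:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh run, nothing to load

print(resolve_resume_file(Path("exp"), start_batch=16000, start_epoch=5))
# exp/checkpoint-16000.pt  (--start-epoch is ignored)
print(resolve_resume_file(Path("exp"), start_batch=0, start_epoch=5))
# exp/epoch-4.pt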
@@ -365,7 +406,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -375,6 +416,10 @@ def save_checkpoint(
         It is returned by :func:`get_params`.
       model:
         The training model.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
     """
     if rank != 0:
         return
@@ -384,7 +429,7 @@ def save_checkpoint(
         model=model,
         params=params,
         optimizer=optimizer,
-        scheduler=scheduler,
+        sampler=sampler,
         rank=rank,
     )
 
@@ -500,6 +545,7 @@ def train_one_epoch(
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
+    rank: int = 0,
 ) -> None:
     """Train the model for one epoch.
 
@@ -522,6 +568,9 @@ def train_one_epoch(
         Writer to write log messages to tensorboard.
       world_size:
         Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
     """
     model.train()
 
@@ -566,7 +615,13 @@ def train_one_epoch(
         else:
             optimizer.step()
 
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
@@ -591,6 +646,27 @@ def train_one_epoch(
 
         optimizer.zero_grad()
 
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                params=params,
+                optimizer=optimizer,
+                sampler=train_dl.sampler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, "
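Note: the two loop hunks above are what make mid-epoch resumption work: cur_batch_idx is stashed in params just before a periodic save (and deleted right after, so epoch-level checkpoints do not carry it), and a resumed run skips every batch whose index is below the saved value. A toy, icefall-independent sketch of that round trip:

# Illustration only: how saving cur_batch_idx lets a resumed run skip
# already-processed batches within the interrupted epoch.
params = {}                      # stands in for the AttributeDict
batches = list(range(10))        # stands in for train_dl

# --- first run: interrupted right after a periodic save at batch 6 ---
params["cur_batch_idx"] = 6      # stored inside checkpoint-xxx.pt
checkpoint = dict(params)        # pretend this was written to disk

# --- resumed run: the checkpoint is loaded back into params ---
params = dict(checkpoint)
cur_batch_idx = params.get("cur_batch_idx", 0)

processed = []
for batch_idx, batch in enumerate(batches):
    if batch_idx < cur_batch_idx:
        continue                 # skip batches finished before the save
    cur_batch_idx = batch_idx
    processed.append(batch)

print(processed)                 # [6, 7, 8, 9]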
@@ -598,8 +674,6 @@ def train_one_epoch(
                 f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )
 
-        if batch_idx % params.log_interval == 0:
-
             if tb_writer is not None:
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
@@ -723,7 +797,14 @@ def run(rank, world_size, args):
     logging.info(f"After removing short and long utterances: {num_left}")
     logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
 
-    train_dl = librispeech.train_dataloaders(train_cuts)
+    if checkpoints and "sampler" in checkpoints:
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
 
     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
@@ -762,12 +843,14 @@ def run(rank, world_size, args):
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
+            rank=rank,
         )
 
         save_checkpoint(
             params=params,
             model=model,
             optimizer=optimizer,
+            sampler=train_dl.sampler,
             rank=rank,
         )
 
@@ -21,6 +21,7 @@ import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional
 
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
@@ -181,8 +182,18 @@ class LibriSpeechAsrDataModule:
             "with training dataset. ",
         )
 
-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
-
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
@@ -286,6 +297,10 @@ class LibriSpeechAsrDataModule:
         )
         logging.info("About to create train dataloader")
 
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
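Note: the data module can now restore the training sampler's state, so a resumed run continues from the same point in the shuffled cut stream instead of replaying data it has already seen. The class below is a mock written only to illustrate the state_dict()/load_state_dict() round trip; it is not lhotse's CutSampler. The hunks that follow (icefall/checkpoint.py, per the imports added in train.py) are what persist that state under the "sampler" key.

# Illustration only: the save/restore pattern that a resumable sampler
# performs when sampler_state_dict is passed to train_dataloaders().
class MockSampler:
    """A stand-in for a resumable sampler; only shows the resume pattern."""

    def __init__(self, num_batches: int) -> None:
        self.num_batches = num_batches
        self.pos = 0                       # how far the epoch has progressed

    def state_dict(self) -> dict:
        return {"pos": self.pos}

    def load_state_dict(self, state: dict) -> None:
        self.pos = state["pos"]

    def __iter__(self):
        while self.pos < self.num_batches:
            batch = self.pos
            self.pos += 1                  # advance before handing the batch out
            yield batch

sampler = MockSampler(num_batches=5)
it = iter(sampler)
next(it), next(it)                         # two batches consumed before the save
saved = sampler.state_dict()               # stored in the checkpoint ("sampler" key)

resumed = MockSampler(num_batches=5)
resumed.load_state_dict(saved)             # what train_dataloaders() now does
print(list(resumed))                       # [2, 3, 4] -- continues, not restarts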
@@ -15,12 +15,16 @@
 # limitations under the License.
 
 
+import glob
 import logging
+import os
+import re
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.nn as nn
+from lhotse.dataset.sampling.base import CutSampler
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -34,6 +38,7 @@ def save_checkpoint(
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[_LRScheduler] = None,
     scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save training information to a file.
@@ -69,6 +74,7 @@ def save_checkpoint(
         "optimizer": optimizer.state_dict() if optimizer is not None else None,
         "scheduler": scheduler.state_dict() if scheduler is not None else None,
         "grad_scaler": scaler.state_dict() if scaler is not None else None,
+        "sampler": sampler.state_dict() if sampler is not None else None,
     }
 
     if params:
@@ -85,6 +91,7 @@ def load_checkpoint(
     optimizer: Optional[Optimizer] = None,
     scheduler: Optional[_LRScheduler] = None,
     scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
     strict: bool = False,
 ) -> Dict[str, Any]:
     """
@@ -117,6 +124,7 @@ def load_checkpoint(
     load("optimizer", optimizer)
     load("scheduler", scheduler)
     load("grad_scaler", scaler)
+    load("sampler", sampler)
 
     return checkpoint
 
@@ -151,3 +159,120 @@ def average_checkpoints(
         avg[k] //= n
 
     return avg
+
+
+def save_checkpoint_with_global_batch_idx(
+    out_dir: Path,
+    global_batch_idx: int,
+    model: Union[nn.Module, DDP],
+    params: Optional[Dict[str, Any]] = None,
+    optimizer: Optional[Optimizer] = None,
+    scheduler: Optional[_LRScheduler] = None,
+    scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
+    rank: int = 0,
+):
+    """Save training info after processing given number of batches.
+
+    Args:
+      out_dir:
+        The directory to save the checkpoint.
+      global_batch_idx:
+        The number of batches processed so far from the very start of the
+        training. The saved checkpoint will have the following filename:
+
+            f'out_dir / checkpoint-{global_batch_idx}.pt'
+      model:
+        The neural network model whose `state_dict` will be saved in the
+        checkpoint.
+      params:
+        A dict of training configurations to be saved.
+      optimizer:
+        The optimizer used in the training. Its `state_dict` will be saved.
+      scheduler:
+        The learning rate scheduler used in the training. Its `state_dict` will
+        be saved.
+      scaler:
+        The scaler used for mix precision training. Its `state_dict` will
+        be saved.
+      sampler:
+        The sampler used in the training dataset.
+      rank:
+        The rank ID used in DDP training of the current node. Set it to 0
+        if DDP is not used.
+    """
+    out_dir = Path(out_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
+    save_checkpoint(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        scaler=scaler,
+        sampler=sampler,
+        rank=rank,
+    )
+
+
+def find_checkpoints(out_dir: Path) -> List[str]:
+    """Find all available checkpoints in a directory.
+
+    The checkpoint filenames have the form: `checkpoint-xxx.pt`
+    where xxx is a numerical value.
+
+    Args:
+      out_dir:
+        The directory where to search for checkpoints.
+    Returns:
+      Return a list of checkpoint filenames, sorted in descending
+      order by the numerical value in the filename.
+    """
+    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
+    pattern = re.compile(r"checkpoint-([0-9]+).pt")
+    idx_checkpoints = [
+        (int(pattern.search(c).group(1)), c) for c in checkpoints
+    ]
+
+    idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
+    ans = [ic[1] for ic in idx_checkpoints]
+    return ans
+
+
+def remove_checkpoints(
+    out_dir: Path,
+    topk: int,
+    rank: int = 0,
+):
+    """Remove checkpoints from the given directory.
+
+    We assume that checkpoint filename has the form `checkpoint-xxx.pt`
+    where xxx is a number, representing the number of processed batches
+    when saving that checkpoint. We sort checkpoints by filename and keep
+    only the `topk` checkpoints with the highest `xxx`.
+
+    Args:
+      out_dir:
+        The directory containing checkpoints to be removed.
+      topk:
+        Number of checkpoints to keep.
+      rank:
+        If using DDP for training, it is the rank of the current node.
+        Use 0 if no DDP is used for training.
+    """
+    assert topk >= 1, topk
+    if rank != 0:
+        return
+    checkpoints = find_checkpoints(out_dir)
+
+    if len(checkpoints) == 0:
+        logging.warn(f"No checkpoints found in {out_dir}")
+        return
+
+    if len(checkpoints) <= topk:
+        return
+
+    to_remove = checkpoints[topk:]
+    for c in to_remove:
+        os.remove(c)
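Note: assuming icefall and PyTorch are importable, the three new helpers in icefall/checkpoint.py compose as in the sketch below; the toy model, directory, and batch indices are invented for illustration and are not part of the commit.

# Usage sketch for the new helpers (assumes icefall is on PYTHONPATH and
# PyTorch is installed; the model and directory are throwaway examples).
import tempfile
from pathlib import Path

import torch.nn as nn

from icefall.checkpoint import (
    find_checkpoints,
    remove_checkpoints,
    save_checkpoint_with_global_batch_idx,
)

model = nn.Linear(4, 2)                    # stand-in for the real transducer
out_dir = Path(tempfile.mkdtemp())

# Periodic saves at three (pretend) global batch indices.
for global_batch_idx in (8000, 16000, 24000):
    save_checkpoint_with_global_batch_idx(
        out_dir=out_dir,
        global_batch_idx=global_batch_idx,
        model=model,
    )

print(find_checkpoints(out_dir))
# ['.../checkpoint-24000.pt', '.../checkpoint-16000.pt', '.../checkpoint-8000.pt']

remove_checkpoints(out_dir=out_dir, topk=2)   # keep only the 2 newest
print(find_checkpoints(out_dir))
# ['.../checkpoint-24000.pt', '.../checkpoint-16000.pt']

remove_checkpoints() only acts on rank 0 and only matches checkpoint-*.pt, so epoch-xxx.pt files written at the end of each epoch are never deleted.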