Save averaged models periodically during training

This commit is contained in:
Fangjun Kuang 2022-05-23 18:15:55 +08:00
parent bf3df442c6
commit b7676ca1f2

View File

@ -43,6 +43,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
""" """
import argparse import argparse
import copy
import logging import logging
import warnings import warnings
from pathlib import Path from pathlib import Path
@ -72,7 +73,10 @@ from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
@ -126,10 +130,10 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--start-epoch", "--start-epoch",
type=int, type=int,
default=0, default=1,
help="""Resume training from from this epoch. help="""Resume training from this epoch. It should be positive.
If it is positive, it will load checkpoint from If larger than 1, it will load checkpoint from
transducer_stateless2/exp/epoch-{start_epoch-1}.pt exp-dir/epoch-{start_epoch-1}.pt
""", """,
) )
@ -163,15 +167,16 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need to be changed.", help="The initial learning rate. This value should not need to "
"be changed.",
) )
parser.add_argument( parser.add_argument(
"--lr-batches", "--lr-batches",
type=float, type=float,
default=5000, default=5000,
help="""Number of steps that affects how rapidly the learning rate decreases. help="""Number of steps that affects how rapidly the learning
We suggest not to change this.""", rate decreases. We suggest not to change this.""",
) )
parser.add_argument( parser.add_argument(
@ -249,7 +254,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--save-every-n", "--save-every-n",
type=int, type=int,
default=8000, default=4000,
help="""Save checkpoint after processing this number of batches" help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
@ -262,7 +267,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--keep-last-k", "--keep-last-k",
type=int, type=int,
default=20, default=30,
help="""Only keep this number of checkpoints on disk. help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`. in the exp-dir with filenames `checkpoint-xxx.pt`.
@ -270,6 +275,19 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--average-period",
type=int,
default=100,
help="""Update the averaged model, namely `model_avg`, after processing
this number of batches. `model_avg` is a separate version of model,
in which each floating-point parameter is the average of all the
parameters from the start of training. Each time we take the average,
we do: `model_avg = model * (average_period / batch_idx_train) +
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
""",
)
parser.add_argument( parser.add_argument(
"--use-fp16", "--use-fp16",
type=str2bool, type=str2bool,
@ -408,6 +426,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
def load_checkpoint_if_available( def load_checkpoint_if_available(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None, optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None, scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
@ -415,7 +434,7 @@ def load_checkpoint_if_available(
If params.start_batch is positive, it will load the checkpoint from If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is positive, it will load the checkpoint from params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`. `params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates Apart from loading state dict for `model` and `optimizer` it also updates
@ -427,6 +446,8 @@ def load_checkpoint_if_available(
The return value of :func:`get_params`. The return value of :func:`get_params`.
model: model:
The training model. The training model.
model_avg:
The stored model averaged from the start of training.
optimizer: optimizer:
The optimizer that we are using. The optimizer that we are using.
scheduler: scheduler:
@ -436,7 +457,7 @@ def load_checkpoint_if_available(
""" """
if params.start_batch > 0: if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 0: elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else: else:
return None return None
@ -446,6 +467,7 @@ def load_checkpoint_if_available(
saved_params = load_checkpoint( saved_params = load_checkpoint(
filename, filename,
model=model, model=model,
model_avg=model_avg,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler, scheduler=scheduler,
) )
@ -472,7 +494,8 @@ def load_checkpoint_if_available(
def save_checkpoint( def save_checkpoint(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None, optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None, scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None, sampler: Optional[CutSampler] = None,
@ -486,6 +509,8 @@ def save_checkpoint(
It is returned by :func:`get_params`. It is returned by :func:`get_params`.
model: model:
The training model. The training model.
model_avg:
The stored model averaged from the start of training.
optimizer: optimizer:
The optimizer used in the training. The optimizer used in the training.
sampler: sampler:
@ -499,6 +524,7 @@ def save_checkpoint(
save_checkpoint_impl( save_checkpoint_impl(
filename=filename, filename=filename,
model=model, model=model,
model_avg=model_avg,
params=params, params=params,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler, scheduler=scheduler,
@ -539,7 +565,7 @@ def compute_loss(
function enables autograd during computation; when it is False, it function enables autograd during computation; when it is False, it
disables autograd. disables autograd.
""" """
device = model.device device = params.device
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -624,6 +650,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler, scaler: GradScaler,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None, tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1, world_size: int = 1,
rank: int = 0, rank: int = 0,
@ -649,6 +676,8 @@ def train_one_epoch(
Dataloader for the validation dataset. Dataloader for the validation dataset.
scaler: scaler:
The scaler used for mix precision training. The scaler used for mix precision training.
model_avg:
The stored model averaged from the start of training.
tb_writer: tb_writer:
Writer to write log messages to tensorboard. Writer to write log messages to tensorboard.
world_size: world_size:
@ -695,51 +724,68 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16): try:
loss, loss_info = compute_loss( with torch.cuda.amp.autocast(enabled=params.use_fp16):
params=params, loss, loss_info = compute_loss(
model=model, params=params,
sp=sp, model=model,
batch=batch, sp=sp,
is_training=True, batch=batch,
) is_training=True,
# summary stats )
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
scaler.scale(loss).backward() scaler.scale(loss).backward()
maybe_log_weights("train/param_norms") maybe_log_weights("train/param_norms")
maybe_log_gradients("train/grad_norms") maybe_log_gradients("train/grad_norms")
old_parameters = None old_parameters = None
if ( if (
params.log_diagnostics params.log_diagnostics
and tb_writer is not None and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0 and params.batch_idx_train % (params.log_interval * 5) == 0
): ):
old_parameters = { old_parameters = {
n: p.detach().clone() for n, p in model.named_parameters() n: p.detach().clone() for n, p in model.named_parameters()
} }
scheduler.step_batch(params.batch_idx_train) scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
if old_parameters is not None: if old_parameters is not None:
deltas = optim_step_and_measure_param_change(model, old_parameters) deltas = optim_step_and_measure_param_change(
tb_writer.add_scalars( model, old_parameters
"train/relative_param_change_per_minibatch", )
deltas, tb_writer.add_scalars(
global_step=params.batch_idx_train, "train/relative_param_change_per_minibatch",
) deltas,
global_step=params.batch_idx_train,
)
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, sp=sp)
raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
return return
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if ( if (
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
@ -749,6 +795,7 @@ def train_one_epoch(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
model=model, model=model,
model_avg=model_avg,
params=params, params=params,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler, scheduler=scheduler,
@ -841,6 +888,8 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank) device = torch.device("cuda", rank)
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
params.device = device
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
@ -856,13 +905,23 @@ def run(rank, world_size, args):
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model) assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params,
model=model,
model_avg=model_avg,
)
model.to(device) model.to(device)
if world_size > 1: if world_size > 1:
logging.info("Using DDP") logging.info("Using DDP")
model = DDP(model, device_ids=[rank]) model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Eve(model.parameters(), lr=params.initial_lr) optimizer = Eve(model.parameters(), lr=params.initial_lr)
@ -935,10 +994,10 @@ def run(rank, world_size, args):
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"]) scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch) scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch) fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
@ -948,6 +1007,7 @@ def run(rank, world_size, args):
train_one_epoch( train_one_epoch(
params=params, params=params,
model=model, model=model,
model_avg=model_avg,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler, scheduler=scheduler,
sp=sp, sp=sp,
@ -966,6 +1026,7 @@ def run(rank, world_size, args):
save_checkpoint( save_checkpoint(
params=params, params=params,
model=model, model=model,
model_avg=model_avg,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler, scheduler=scheduler,
sampler=train_dl.sampler, sampler=train_dl.sampler,
@ -980,6 +1041,38 @@ def run(rank, world_size, args):
cleanup_dist() cleanup_dist()
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
def scan_pessimistic_batches_for_oom( def scan_pessimistic_batches_for_oom(
model: nn.Module, model: nn.Module,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
@ -1016,6 +1109,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch(batch, params=params, sp=sp)
raise raise