mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Save averaged models periodically during training
This commit is contained in:
parent
bf3df442c6
commit
b7676ca1f2
@ -43,6 +43,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@ -72,7 +73,10 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.checkpoint import (
|
||||
save_checkpoint_with_global_batch_idx,
|
||||
update_averaged_model,
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
@ -126,10 +130,10 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
||||
default=1,
|
||||
help="""Resume training from this epoch. It should be positive.
|
||||
If larger than 1, it will load checkpoint from
|
||||
exp-dir/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
@ -163,15 +167,16 @@ def get_parser():
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate. This value should not need to be changed.",
|
||||
help="The initial learning rate. This value should not need to "
|
||||
"be changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr-batches",
|
||||
type=float,
|
||||
default=5000,
|
||||
help="""Number of steps that affects how rapidly the learning rate decreases.
|
||||
We suggest not to change this.""",
|
||||
help="""Number of steps that affects how rapidly the learning
|
||||
rate decreases. We suggest not to change this.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -249,7 +254,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=8000,
|
||||
default=4000,
|
||||
help="""Save checkpoint after processing this number of batches"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
@ -262,7 +267,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--keep-last-k",
|
||||
type=int,
|
||||
default=20,
|
||||
default=30,
|
||||
help="""Only keep this number of checkpoints on disk.
|
||||
For instance, if it is 3, there are only 3 checkpoints
|
||||
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
||||
@ -270,6 +275,19 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--average-period",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Update the averaged model, namely `model_avg`, after processing
|
||||
this number of batches. `model_avg` is a separate version of model,
|
||||
in which each floating-point parameter is the average of all the
|
||||
parameters from the start of training. Each time we take the average,
|
||||
we do: `model_avg = model * (average_period / batch_idx_train) +
|
||||
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fp16",
|
||||
type=str2bool,
|
||||
@ -408,6 +426,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
@ -415,7 +434,7 @@ def load_checkpoint_if_available(
|
||||
|
||||
If params.start_batch is positive, it will load the checkpoint from
|
||||
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
||||
params.start_epoch is positive, it will load the checkpoint from
|
||||
params.start_epoch is larger than 1, it will load the checkpoint from
|
||||
`params.start_epoch - 1`.
|
||||
|
||||
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||
@ -427,6 +446,8 @@ def load_checkpoint_if_available(
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
@ -436,7 +457,7 @@ def load_checkpoint_if_available(
|
||||
"""
|
||||
if params.start_batch > 0:
|
||||
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
||||
elif params.start_epoch > 0:
|
||||
elif params.start_epoch > 1:
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
else:
|
||||
return None
|
||||
@ -446,6 +467,7 @@ def load_checkpoint_if_available(
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
@ -472,7 +494,8 @@ def load_checkpoint_if_available(
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
@ -486,6 +509,8 @@ def save_checkpoint(
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
optimizer:
|
||||
The optimizer used in the training.
|
||||
sampler:
|
||||
@ -499,6 +524,7 @@ def save_checkpoint(
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
@ -539,7 +565,7 @@ def compute_loss(
|
||||
function enables autograd during computation; when it is False, it
|
||||
disables autograd.
|
||||
"""
|
||||
device = model.device
|
||||
device = params.device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -624,6 +650,7 @@ def train_one_epoch(
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -649,6 +676,8 @@ def train_one_epoch(
|
||||
Dataloader for the validation dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
@ -695,51 +724,68 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
maybe_log_weights("train/param_norms")
|
||||
maybe_log_gradients("train/grad_norms")
|
||||
maybe_log_weights("train/param_norms")
|
||||
maybe_log_gradients("train/grad_norms")
|
||||
|
||||
old_parameters = None
|
||||
if (
|
||||
params.log_diagnostics
|
||||
and tb_writer is not None
|
||||
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||
):
|
||||
old_parameters = {
|
||||
n: p.detach().clone() for n, p in model.named_parameters()
|
||||
}
|
||||
old_parameters = None
|
||||
if (
|
||||
params.log_diagnostics
|
||||
and tb_writer is not None
|
||||
and params.batch_idx_train % (params.log_interval * 5) == 0
|
||||
):
|
||||
old_parameters = {
|
||||
n: p.detach().clone() for n, p in model.named_parameters()
|
||||
}
|
||||
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
if old_parameters is not None:
|
||||
deltas = optim_step_and_measure_param_change(model, old_parameters)
|
||||
tb_writer.add_scalars(
|
||||
"train/relative_param_change_per_minibatch",
|
||||
deltas,
|
||||
global_step=params.batch_idx_train,
|
||||
)
|
||||
if old_parameters is not None:
|
||||
deltas = optim_step_and_measure_param_change(
|
||||
model, old_parameters
|
||||
)
|
||||
tb_writer.add_scalars(
|
||||
"train/relative_param_change_per_minibatch",
|
||||
deltas,
|
||||
global_step=params.batch_idx_train,
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
optimizer.zero_grad()
|
||||
except: # noqa
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
if (
|
||||
rank == 0
|
||||
and params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.average_period == 0
|
||||
):
|
||||
update_averaged_model(
|
||||
params=params,
|
||||
model_cur=model,
|
||||
model_avg=model_avg,
|
||||
)
|
||||
|
||||
if (
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
@ -749,6 +795,7 @@ def train_one_epoch(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
@ -841,6 +888,8 @@ def run(rank, world_size, args):
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
params.device = device
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
@ -856,13 +905,23 @@ def run(rank, world_size, args):
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
assert params.save_every_n >= params.average_period
|
||||
model_avg: Optional[nn.Module] = None
|
||||
if rank == 0:
|
||||
# model_avg is only used with rank 0
|
||||
model_avg = copy.deepcopy(model)
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
params=params,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank])
|
||||
model.device = device
|
||||
|
||||
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
||||
|
||||
@ -935,10 +994,10 @@ def run(rank, world_size, args):
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
scheduler.step_epoch(epoch)
|
||||
fix_random_seed(params.seed + epoch)
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||
scheduler.step_epoch(epoch - 1)
|
||||
fix_random_seed(params.seed + epoch - 1)
|
||||
train_dl.sampler.set_epoch(epoch - 1)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
@ -948,6 +1007,7 @@ def run(rank, world_size, args):
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sp=sp,
|
||||
@ -966,6 +1026,7 @@ def run(rank, world_size, args):
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
@ -980,6 +1041,38 @@ def run(rank, world_size, args):
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def display_and_save_batch(
|
||||
batch: dict,
|
||||
params: AttributeDict,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> None:
|
||||
"""Display the batch statistics and save the batch into disk.
|
||||
|
||||
Args:
|
||||
batch:
|
||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||
for the content in it.
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
sp:
|
||||
The BPE model.
|
||||
"""
|
||||
from lhotse.utils import uuid4
|
||||
|
||||
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
|
||||
logging.info(f"Saving batch to {filename}")
|
||||
torch.save(batch, filename)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
features = batch["inputs"]
|
||||
|
||||
logging.info(f"features shape: {features.shape}")
|
||||
|
||||
y = sp.encode(supervisions["text"], out_type=int)
|
||||
num_tokens = sum(len(i) for i in y)
|
||||
logging.info(f"num tokens: {num_tokens}")
|
||||
|
||||
|
||||
def scan_pessimistic_batches_for_oom(
|
||||
model: nn.Module,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
@ -1016,6 +1109,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user