Save checkpoint on failure.
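
Adds a save_bad_model() helper to train_one_epoch() that writes a per-rank checkpoint (bad-model{suffix}-{rank}.pt) containing the model, averaged model, optimizer, scheduler, sampler and grad-scaler state. It is called when the backward/optimizer step raises, once with a "-first-warning" suffix when the grad scale first drops below 0.01, and again just before raising RuntimeError when the grad scale falls below 1e-5.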

Daniel Povey 2022-12-21 13:42:16 +08:00
parent 96d167a2ec
commit 266e71cc79


@@ -853,6 +853,19 @@ def train_one_epoch(
    cur_batch_idx = params.get("cur_batch_idx", 0)

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        save_checkpoint_impl(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            model_avg=model_avg,
            params=params,
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=sampler,
            scaler=scaler,
            rank=0,
        )

    for batch_idx, batch in enumerate(train_dl):
        if batch_idx % 10 == 0:
            set_batch_count(model, get_adjusted_batch_count(params))
@@ -884,6 +897,7 @@ def train_one_epoch(
            scaler.update()
            optimizer.zero_grad()
        except:  # noqa
            save_bad_model()
            display_and_save_batch(batch, params=params, sp=sp)
            raise
@@ -934,8 +948,12 @@ def train_one_epoch(
            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    save_bad_model()
                    raise RuntimeError(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )

        if batch_idx % params.log_interval == 0:
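
Outside the icefall codebase, the same save-on-failure pattern can be reproduced with stock PyTorch. The following is a minimal sketch, not the repository's code: a plain torch.save stands in for save_checkpoint_impl, the train_one_epoch/save_bad_model names and the 0.01 / 1e-5 thresholds mirror the diff above, and cur_grad_scale is read via the public scaler.get_scale().

    import logging

    import torch
    from torch.cuda.amp import GradScaler, autocast


    def train_one_epoch(model, optimizer, train_dl, exp_dir, rank=0):
        scaler = GradScaler(enabled=torch.cuda.is_available())
        saved_bad_model = False

        def save_bad_model(suffix: str = ""):
            # Stand-in for save_checkpoint_impl(): dump enough state to
            # reproduce the failure offline.
            torch.save(
                {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "grad_scaler": scaler.state_dict(),
                },
                f"{exp_dir}/bad-model{suffix}-{rank}.pt",
            )

        for batch_idx, (feature, target) in enumerate(train_dl):
            try:
                with autocast(enabled=scaler.is_enabled()):
                    loss = model(feature, target)  # assumed scalar loss
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            except Exception:
                # Checkpoint the offending state before re-raising, so the
                # failure can be inspected after the job dies.
                save_bad_model()
                raise

            # Watchdog on the loss scale: GradScaler halves it on every
            # inf/nan step, so a collapsing scale means training has diverged.
            cur_grad_scale = scaler.get_scale()
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    save_bad_model()
                    raise RuntimeError(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )

The diff additionally calls display_and_save_batch() in the except handler so the offending batch is saved alongside the model, and the "-first-warning" checkpoint is gated by saved_bad_model so a slowly collapsing scale writes it only once.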