Save checkpoint on failure.
This commit is contained in:
parent
96d167a2ec
commit
266e71cc79
@ -853,6 +853,19 @@ def train_one_epoch(
|
|||||||
|
|
||||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||||
|
|
||||||
|
saved_bad_model = False
|
||||||
|
def save_bad_model(suffix: str = ""):
    """Write the current training state to a ``bad-model`` checkpoint file.

    The checkpoint is written under ``params.exp_dir`` with the name
    ``bad-model{suffix}-{rank}.pt``. The model, averaged model, params,
    optimizer, scheduler, sampler and grad scaler are not arguments; they
    are picked up from the enclosing scope.

    Args:
      suffix:
        Optional text inserted between ``bad-model`` and the rank in the
        file name (e.g. ``"-first-warning"``). Defaults to no suffix.
    """
    # The real rank goes into the file name only.
    bad_model_path = params.exp_dir / f"bad-model{suffix}-{rank}.pt"

    # NOTE(review): rank=0 is passed regardless of the actual rank --
    # presumably so that every process writes its own file; confirm against
    # save_checkpoint_impl.
    save_checkpoint_impl(
        filename=bad_model_path,
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        sampler=sampler,
        scaler=scaler,
        rank=0,
    )
|
||||||
|
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
if batch_idx % 10 == 0:
|
if batch_idx % 10 == 0:
|
||||||
set_batch_count(model, get_adjusted_batch_count(params))
|
set_batch_count(model, get_adjusted_batch_count(params))
|
||||||
@ -884,6 +897,7 @@ def train_one_epoch(
|
|||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
except: # noqa
|
except: # noqa
|
||||||
|
save_bad_model()
|
||||||
display_and_save_batch(batch, params=params, sp=sp)
|
display_and_save_batch(batch, params=params, sp=sp)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -934,8 +948,12 @@ def train_one_epoch(
|
|||||||
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
|
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
|
||||||
scaler.update(cur_grad_scale * 2.0)
|
scaler.update(cur_grad_scale * 2.0)
|
||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
|
if not saved_bad_model:
|
||||||
|
save_bad_model(suffix="-first-warning")
|
||||||
|
saved_bad_model = True
|
||||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||||
if cur_grad_scale < 1.0e-05:
|
if cur_grad_scale < 1.0e-05:
|
||||||
|
save_bad_model()
|
||||||
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
|
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user