mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Save checkpoint on failure.
This commit is contained in:
parent
96d167a2ec
commit
266e71cc79
@ -853,6 +853,19 @@ def train_one_epoch(
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
saved_bad_model = False
|
||||
def save_bad_model(suffix: str = ""):
    """Write the current training state to a ``bad-model`` checkpoint file.

    The checkpoint is placed in ``params.exp_dir`` and is named
    ``bad-model{suffix}-{rank}.pt``, where ``rank`` is taken from the
    enclosing scope.

    Args:
      suffix:
        Text inserted into the filename directly after ``bad-model``.
        Defaults to the empty string.
    """
    bad_model_path = params.exp_dir / f"bad-model{suffix}-{rank}.pt"

    # Everything handed to the checkpoint writer besides the destination
    # path and the rank; all of these come from the enclosing scope.
    training_state = dict(
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        sampler=sampler,
        scaler=scaler,
    )

    # NOTE(review): rank=0 is passed literally even though the filename
    # embeds the enclosing `rank`; presumably this lets every process write
    # its own file -- confirm against save_checkpoint_impl.
    save_checkpoint_impl(filename=bad_model_path, rank=0, **training_state)
|
||||
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx % 10 == 0:
|
||||
set_batch_count(model, get_adjusted_batch_count(params))
|
||||
@ -884,6 +897,7 @@ def train_one_epoch(
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
except: # noqa
|
||||
save_bad_model()
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
|
||||
@ -934,8 +948,12 @@ def train_one_epoch(
|
||||
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
|
||||
scaler.update(cur_grad_scale * 2.0)
|
||||
if cur_grad_scale < 0.01:
|
||||
if not saved_bad_model:
|
||||
save_bad_model(suffix="-first-warning")
|
||||
saved_bad_model = True
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user