diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index c015cddde..0e9fe1423 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -853,6 +853,19 @@ def train_one_epoch(
 
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
+    saved_bad_model = False
+
+    def save_bad_model(suffix: str = ""):
+        save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+                             model=model,
+                             model_avg=model_avg,
+                             params=params,
+                             optimizer=optimizer,
+                             scheduler=scheduler,
+                             sampler=sampler,
+                             scaler=scaler,
+                             rank=0)
+
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
@@ -884,6 +897,7 @@
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
+            save_bad_model()
             display_and_save_batch(batch, params=params, sp=sp)
             raise
 
@@ -934,8 +948,12 @@
             if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
+                if not saved_bad_model:
+                    save_bad_model(suffix="-first-warning")
+                    saved_bad_model = True
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
                 if cur_grad_scale < 1.0e-05:
+                    save_bad_model()
                     raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
 
         if batch_idx % params.log_interval == 0:
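
Note (not part of the patch): a minimal sketch of how a checkpoint written by save_bad_model might be inspected offline. The file name follows the bad-model{suffix}-{rank}.pt pattern above; the "exp" directory and the "model" key are assumptions based on icefall's usual checkpoint layout, not something this diff guarantees.

import torch

# Load the bad-model checkpoint on CPU; no GPU is needed for inspection.
# The path is illustrative -- substitute the actual exp_dir, suffix, and rank.
ckpt = torch.load("exp/bad-model-first-warning-0.pt", map_location="cpu")

# icefall checkpoints usually store the raw state_dict under "model"
# (an assumption here); scan it for non-finite values, the typical
# culprit when the grad scale collapses below the thresholds checked above.
for name, tensor in ckpt["model"].items():
    if tensor.is_floating_point() and not torch.isfinite(tensor).all():
        print(f"non-finite parameter: {name}")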