diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index e71f0d1c6..0f51b4382 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -395,9 +395,10 @@ def load_checkpoint_if_available( "cur_batch_idx", ] for k in keys: - params[k] = saved_params[k] + params[k] = saved_params.get(k, 0) - params["start_epoch"] = saved_params["cur_epoch"] + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] return saved_params