diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 6ff4f16ec..80f8c238f 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -278,7 +278,7 @@ def load_checkpoint_if_available( elif params.start_epoch > 1: filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" else: - return + return None logging.info(f"Loading checkpoint: {filename}") saved_params = load_checkpoint( @@ -298,6 +298,10 @@ def load_checkpoint_if_available( for k in keys: params[k] = saved_params[k] + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + return saved_params