diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 33478c630..a08ebad54 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -21,20 +21,20 @@ Usage:
 
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
-./pruned_transducer_stateless2/train.py \
+./pruned_transducer_stateless4/train.py \
   --world-size 4 \
   --num-epochs 30 \
-  --start-epoch 0 \
+  --start-epoch 1 \
   --exp-dir pruned_transducer_stateless2/exp \
   --full-libri 1 \
   --max-duration 300
 
 # For mixed precision training:
 
-./pruned_transducer_stateless2/train.py \
+./pruned_transducer_stateless4/train.py \
   --world-size 4 \
   --num-epochs 30 \
-  --start-epoch 0 \
+  --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless2/exp \
   --full-libri 1 \
@@ -123,7 +123,7 @@ def get_parser():
     parser.add_argument(
         "--start-epoch",
         type=int,
-        default=0,
+        default=1,
         help="""Resume training from this epoch.
         If it is positive, it will load checkpoint from
         transducer_stateless2/exp/epoch-{start_epoch-1}.pt
@@ -418,7 +418,7 @@ def load_checkpoint_if_available(
 
     If params.start_batch is positive, it will load the checkpoint from
     `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
-    params.start_epoch is positive, it will load the checkpoint from
+    params.start_epoch is larger than 1, it will load the checkpoint from
     `params.start_epoch - 1`.
 
     Apart from loading state dict for `model` and `optimizer` it also updates
@@ -430,6 +430,8 @@
         The return value of :func:`get_params`.
       model:
         The training model.
+      model_avg:
+        The stored model averaged from the start of training.
       optimizer:
         The optimizer that we are using.
       scheduler:
@@ -439,7 +441,7 @@
     """
     if params.start_batch > 0:
         filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-    elif params.start_epoch > 0:
+    elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
         return None
@@ -849,7 +851,7 @@ def run(rank, world_size, args):
     logging.info(f"Number of model parameters: {num_param}")
 
     assert params.save_every_n >= params.average_period
-    model_avg: nn.Module = None
+    model_avg: Optional[nn.Module] = None
     if rank == 0:
         # model_avg is only used with rank 0
         model_avg = copy.deepcopy(model)
@@ -939,10 +941,10 @@ def run(rank, world_size, args):
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
 
-    for epoch in range(params.start_epoch, params.num_epochs):
-        scheduler.step_epoch(epoch)
-        fix_random_seed(params.seed + epoch)
-        train_dl.sampler.set_epoch(epoch)
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
 
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
@@ -996,7 +998,7 @@ def scan_pessimistic_batches_for_oom(
     from lhotse.dataset import find_pessimistic_batches
 
     logging.info(
-        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
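
Note on the epoch renumbering in this patch: the user-visible epoch (the `--start-epoch` flag, `epoch-*.pt` checkpoint filenames, TensorBoard logs) now starts at 1, while the scheduler, random seed, and sampler still receive the old 0-based value via `epoch - 1`, and resuming with `--start-epoch N` loads `epoch-{N-1}.pt`. Below is a minimal standalone sketch of that mapping; the stub classes and the `resume_filename` helper are illustrative only, not icefall APIs.

    from typing import Optional

    # Illustrative stubs only -- not icefall APIs.
    class SchedulerStub:
        def step_epoch(self, epoch: int) -> None:
            print(f"scheduler.step_epoch({epoch})  # still 0-based internally")

    class SamplerStub:
        def set_epoch(self, epoch: int) -> None:
            print(f"sampler.set_epoch({epoch})     # still 0-based internally")

    def resume_filename(start_epoch: int, start_batch: int = 0) -> Optional[str]:
        # Mirrors the filename logic in load_checkpoint_if_available above.
        if start_batch > 0:
            return f"checkpoint-{start_batch}.pt"
        elif start_epoch > 1:  # was: start_epoch > 0
            return f"epoch-{start_epoch - 1}.pt"
        return None

    num_epochs = 3
    start_epoch = 1  # new default; a fresh run loads no checkpoint
    scheduler, sampler = SchedulerStub(), SamplerStub()

    print("resume from:", resume_filename(start_epoch))  # -> None
    for epoch in range(start_epoch, num_epochs + 1):  # 1-based: 1..num_epochs
        scheduler.step_epoch(epoch - 1)
        sampler.set_epoch(epoch - 1)
        print(f"would save epoch-{epoch}.pt")  # checkpoint names are 1-based now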
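The new `model_avg` kept on rank 0 is a running average of the model weights, refreshed every `average_period` batches (hence the `assert params.save_every_n >= params.average_period`). The sketch below shows one way to maintain such a uniform average incrementally; the `update_model_avg` name and the exact update rule are assumptions for illustration — the recipe itself delegates this to a helper in icefall's checkpoint utilities.

    import copy
    import torch
    import torch.nn as nn

    def update_model_avg(
        model_avg: nn.Module,
        model_cur: nn.Module,
        batch_idx_train: int,
        average_period: int,
    ) -> None:
        # If model_avg already uniformly averages the snapshots taken up to
        # batch (batch_idx_train - average_period), folding the current
        # weights in with weight average_period / batch_idx_train keeps the
        # average uniform over all snapshots seen so far.
        weight_cur = average_period / batch_idx_train
        avg_state = model_avg.state_dict()
        cur_state = model_cur.state_dict()
        with torch.no_grad():
            for key, avg_param in avg_state.items():
                if avg_param.is_floating_point():  # skip integer buffers
                    avg_param.mul_(1.0 - weight_cur).add_(
                        cur_state[key], alpha=weight_cur
                    )

    # Usage, matching the rank-0-only setup in run():
    model = nn.Linear(4, 4)
    model_avg = copy.deepcopy(model)  # "model_avg is only used with rank 0"
    for batch_idx_train in range(100, 1001, 100):  # every average_period batches
        update_model_avg(model_avg, model, batch_idx_train, average_period=100)

Because each snapshot is folded in with weight `average_period / batch_idx_train`, every contributing snapshot ends up with equal weight, so `model_avg` approximates the uniform average of the weights over the whole run rather than an exponential moving average.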