change epoch number counter starting from 1 instead of 0

This commit is contained in:
yaozengwei 2022-05-05 11:46:49 +08:00
parent 8eb380d796
commit ff3c0d5d86

View File

@ -21,20 +21,20 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless2/train.py \
./pruned_transducer_stateless4/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 300
# For mix precision training:
./pruned_transducer_stateless2/train.py \
./pruned_transducer_stateless4/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
@ -123,7 +123,7 @@ def get_parser():
parser.add_argument(
"--start-epoch",
type=int,
default=0,
default=1,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
@ -418,7 +418,7 @@ def load_checkpoint_if_available(
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is positive, it will load the checkpoint from
params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
@ -430,6 +430,8 @@ def load_checkpoint_if_available(
The return value of :func:`get_params`.
model:
The training model.
model_avg:
The stored model averaged from the start of training.
optimizer:
The optimizer that we are using.
scheduler:
@ -439,7 +441,7 @@ def load_checkpoint_if_available(
"""
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 0:
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
@ -849,7 +851,7 @@ def run(rank, world_size, args):
logging.info(f"Number of model parameters: {num_param}")
assert params.save_every_n >= params.average_period
model_avg: nn.Module = None
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model)
@ -939,10 +941,10 @@ def run(rank, world_size, args):
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs):
scheduler.step_epoch(epoch)
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
@ -996,7 +998,7 @@ def scan_pessimistic_batches_for_oom(
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():