Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)
Change the epoch number counter to start from 1 instead of 0
commit ff3c0d5d86 (parent 8eb380d796)
@@ -21,20 +21,20 @@ Usage:
 
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
-./pruned_transducer_stateless2/train.py \
+./pruned_transducer_stateless4/train.py \
   --world-size 4 \
   --num-epochs 30 \
-  --start-epoch 0 \
+  --start-epoch 1 \
   --exp-dir pruned_transducer_stateless2/exp \
   --full-libri 1 \
   --max-duration 300
 
 # For mix precision training:
 
-./pruned_transducer_stateless2/train.py \
+./pruned_transducer_stateless4/train.py \
   --world-size 4 \
   --num-epochs 30 \
-  --start-epoch 0 \
+  --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless2/exp \
   --full-libri 1 \
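Under the new convention, --start-epoch is 1-based: the default of 1 trains from scratch, and any larger value resumes from the checkpoint written at the end of the previous epoch. A minimal sketch of that mapping, assuming the usual epoch-{N}.pt naming (the helper resume_checkpoint and the example path are illustrative, not taken from train.py):

from pathlib import Path
from typing import Optional

def resume_checkpoint(exp_dir: Path, start_epoch: int) -> Optional[Path]:
    # start_epoch == 1 now means "train from scratch": nothing to load.
    if start_epoch <= 1:
        return None
    # start_epoch > 1 resumes from the previous epoch's checkpoint,
    # e.g. start_epoch=4 -> epoch-3.pt.
    return exp_dir / f"epoch-{start_epoch - 1}.pt"

print(resume_checkpoint(Path("pruned_transducer_stateless4/exp"), 4))  # .../epoch-3.pt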
@@ -123,7 +123,7 @@ def get_parser():
     parser.add_argument(
         "--start-epoch",
         type=int,
-        default=0,
+        default=1,
         help="""Resume training from from this epoch.
         If it is positive, it will load checkpoint from
         transducer_stateless2/exp/epoch-{start_epoch-1}.pt
@@ -418,7 +418,7 @@ def load_checkpoint_if_available(
 
     If params.start_batch is positive, it will load the checkpoint from
     `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
-    params.start_epoch is positive, it will load the checkpoint from
+    params.start_epoch is larger than 1, it will load the checkpoint from
     `params.start_epoch - 1`.
 
     Apart from loading state dict for `model` and `optimizer` it also updates
@@ -430,6 +430,8 @@ def load_checkpoint_if_available(
         The return value of :func:`get_params`.
       model:
         The training model.
+      model_avg:
+        The stored model averaged from the start of training.
       optimizer:
         The optimizer that we are using.
       scheduler:
@@ -439,7 +441,7 @@ def load_checkpoint_if_available(
     """
     if params.start_batch > 0:
         filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-    elif params.start_epoch > 0:
+    elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
         return None
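To make the off-by-one concrete, here is a small self-contained check of the selection rule after the change: batch-level checkpoints still take precedence, and only start_epoch values above 1 imply an epoch checkpoint to load. The helper name and paths below are illustrative, not from the script:

from pathlib import Path
from typing import Optional

def checkpoint_to_load(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    # Mirrors the branch order above: start_batch wins over start_epoch.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    elif start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    else:
        return None

exp = Path("exp")
assert checkpoint_to_load(exp, 0, 1) is None                         # fresh start under 1-based epochs
assert checkpoint_to_load(exp, 0, 3) == exp / "epoch-2.pt"           # resume after finishing epoch 2
assert checkpoint_to_load(exp, 500, 3) == exp / "checkpoint-500.pt"  # batch checkpoint takes precedence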
@@ -849,7 +851,7 @@ def run(rank, world_size, args):
     logging.info(f"Number of model parameters: {num_param}")
 
     assert params.save_every_n >= params.average_period
-    model_avg: nn.Module = None
+    model_avg: Optional[nn.Module] = None
     if rank == 0:
         # model_avg is only used with rank 0
         model_avg = copy.deepcopy(model)
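The widened annotation matches how the variable is actually used: model_avg stays None on every rank except rank 0, so Optional[nn.Module] is the accurate type. A tiny sketch of the pattern, assuming a toy Linear module and a hard-coded rank purely for illustration:

import copy
from typing import Optional

import torch.nn as nn

rank = 0  # in train.py this comes from the distributed launcher
model = nn.Linear(4, 4)

# Non-zero ranks never populate model_avg, so the annotation must admit None.
model_avg: Optional[nn.Module] = None
if rank == 0:
    # Only rank 0 keeps the averaged copy of the model.
    model_avg = copy.deepcopy(model)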
@@ -939,10 +941,10 @@ def run(rank, world_size, args):
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
 
-    for epoch in range(params.start_epoch, params.num_epochs):
-        scheduler.step_epoch(epoch)
-        fix_random_seed(params.seed + epoch)
-        train_dl.sampler.set_epoch(epoch)
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
 
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
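With 1-based epochs, the loop now runs from start_epoch through num_epochs inclusive, while the scheduler, random seed, and sampler still receive the old 0-based index via epoch - 1, so what happens during a given epoch is unchanged; only the reported number and the checkpoint name shift by one. A minimal sketch of just that index arithmetic (print stands in for the real scheduler/sampler calls):

start_epoch, num_epochs = 1, 3

for epoch in range(start_epoch, num_epochs + 1):
    # epoch is 1-based for logging and for checkpoint names (epoch-1.pt, ...);
    # the scheduler, seed, and sampler keep their original 0-based inputs.
    zero_based = epoch - 1
    print(f"epoch {epoch}: scheduler/sampler/seed index {zero_based}")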
@@ -996,7 +998,7 @@ def scan_pessimistic_batches_for_oom(
     from lhotse.dataset import find_pessimistic_batches
 
     logging.info(
-        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
     )
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():