Yifan Yang 2023-05-31 11:11:11 +08:00
parent 584a956bf4
commit 2036652598
2 changed files with 99 additions and 29 deletions

File 1 of 2:

@@ -129,7 +129,7 @@ def get_parser():
     parser.add_argument(
         "--tie-weights",
         type=str2bool,
-        default=False,
+        default=True,
         help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,

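For context, `--tie-weights` shares one weight matrix between the input embedding and the final output projection. Below is a minimal sketch of the idea, assuming the embedding and hidden dimensions are equal (which tying requires); the class and its names are illustrative, not the actual rnn_lm model code:

```python
import torch
import torch.nn as nn


class TiedLM(nn.Module):
    """Toy RNN LM whose output projection reuses the embedding matrix."""

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.rnn = nn.LSTM(dim, dim, num_layers=2, batch_first=True)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # Tying: both layers now hold the same (vocab_size, dim) parameter,
        # which shrinks the model and often acts as a regularizer.
        self.output.weight = self.embedding.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        out, _ = self.rnn(self.embedding(tokens))
        return self.output(out)  # (batch, seq, vocab) logits


lm = TiedLM(vocab_size=500, dim=16)
logits = lm(torch.randint(0, 500, (4, 12)))  # batch of 4, sequence length 12
assert logits.shape == (4, 12, 500)
```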
File 2 of 2: rnn_lm/train.py

@@ -17,20 +17,16 @@
 """
 Usage:
-    ./rnn_lm/train.py \
-        --start-epoch 0 \
-        --world-size 2 \
-        --num-epochs 1 \
-        --use-fp16 0 \
-        --tie-weights 0 \
-        --embedding-dim 800 \
-        --hidden-dim 200 \
-        --num-layers 2 \
-        --batch-size 400
+    ./rnn_lm/train.py \
+        --exp-dir rnn_lm/exp \
+        --start-epoch 1 \
+        --world-size 1 \
+        --num-epochs 30 \
+        --batch-size 150
 """

 import argparse
+import copy
 import logging
 import math
 from pathlib import Path
@@ -50,7 +46,10 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import save_checkpoint_with_global_batch_idx
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -92,13 +91,22 @@ def get_parser():
     parser.add_argument(
         "--start-epoch",
         type=int,
-        default=0,
-        help="""Resume training from this epoch.
-        If it is positive, it will load checkpoint from
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
         exp_dir/epoch-{start_epoch-1}.pt
         """,
     )

+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
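Taken together, the two flags give `--start-batch` priority over `--start-epoch`. A standalone restatement of that resolution order, using the filename patterns documented in the help text (`resolve_checkpoint` is a hypothetical helper for illustration, not part of train.py):

```python
from pathlib import Path
from typing import Optional


def resolve_checkpoint(exp_dir: Path, start_epoch: int, start_batch: int) -> Optional[Path]:
    # --start-batch, when positive, overrides --start-epoch entirely.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    # Epoch N resumes from the checkpoint written at the end of epoch N-1.
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh run: nothing to load


exp = Path("rnn_lm/exp")
assert resolve_checkpoint(exp, start_epoch=1, start_batch=0) is None
assert resolve_checkpoint(exp, start_epoch=11, start_batch=0) == exp / "epoch-10.pt"
assert resolve_checkpoint(exp, start_epoch=11, start_batch=40000) == exp / "checkpoint-40000.pt"
```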
@@ -112,14 +120,14 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=True,
+        default=False,
         help="Whether to use half precision training.",
     )

     parser.add_argument(
         "--batch-size",
         type=int,
-        default=400,
+        default=150,
     )

     parser.add_argument(
@@ -197,7 +205,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=2000,
+        default=20000,
         help="""Save checkpoint after processing this number of batches
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -207,6 +215,19 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
     return parser
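`update_averaged_model` itself lives in `icefall.checkpoint`; the sketch below is not that implementation, only the arithmetic spelled out in the help text above, with `weight = average_period / batch_idx_train`:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def update_average(
    model_cur: nn.Module,
    model_avg: nn.Module,
    average_period: int,
    batch_idx_train: int,
) -> None:
    # model_avg <- model_cur * w + model_avg * (1 - w), per the help text.
    weight = average_period / batch_idx_train
    for p_avg, p_cur in zip(model_avg.parameters(), model_cur.parameters()):
        # Cast p_cur up to p_avg's dtype, since the averaged copy may be
        # kept at higher precision than the training model.
        p_avg.mul_(1.0 - weight).add_(p_cur.to(p_avg.dtype), alpha=weight)
```

Applied every `average_period` batches, this keeps `model_avg` equal to a running mean of the periodic parameter snapshots; the diff stores it in float64, presumably to limit rounding drift over many small updates.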
@@ -225,9 +246,9 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 100,
+            "log_interval": 200,
             "reset_interval": 2000,
-            "valid_interval": 200,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )
@ -237,13 +258,16 @@ def get_params() -> AttributeDict:
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
model_avg: nn.Module = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
@ -254,6 +278,8 @@ def load_checkpoint_if_available(
The return value of :func:`get_params`.
model:
The training model.
model_avg:
The stored model averaged from the start of training.
optimizer:
The optimizer that we are using.
scheduler:
@ -261,14 +287,20 @@ def load_checkpoint_if_available(
@@ -261,14 +287,20 @@ def load_checkpoint_if_available(
     Returns:
       Return None.
     """
-    if params.start_epoch <= 0:
-        return
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None

-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    assert filename.is_file(), f"{filename} does not exist!"
+
     logging.info(f"Loading checkpoint: {filename}")
     saved_params = load_checkpoint(
         filename,
         model=model,
+        model_avg=model_avg,
         optimizer=optimizer,
         scheduler=scheduler,
     )
@ -283,12 +315,20 @@ def load_checkpoint_if_available(
for k in keys:
params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
model_avg: nn.Module = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
@@ -300,6 +340,8 @@ def save_checkpoint(
         It is returned by :func:`get_params`.
       model:
         The training model.
+      model_avg:
+        The stored model averaged from the start of training.
     """
     if rank != 0:
         return
@@ -307,6 +349,7 @@ def save_checkpoint(
     save_checkpoint_impl(
         filename=filename,
         model=model,
+        model_avg=model_avg,
         params=params,
         optimizer=optimizer,
         scheduler=scheduler,
@@ -408,6 +451,7 @@ def train_one_epoch(
     optimizer: torch.optim.Optimizer,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -423,6 +467,8 @@ def train_one_epoch(
         It is returned by :func:`get_params`.
       model:
         The model for training.
+      model_avg:
+        The stored model averaged from the start of training.
       optimizer:
         The optimizer we are using.
       train_dl:
@@ -459,6 +505,18 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
         if (
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
@@ -467,6 +525,7 @@ def train_one_epoch(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
                 model=model,
+                model_avg=model_avg,
                 params=params,
                 optimizer=optimizer,
                 rank=rank,
@@ -576,7 +635,16 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoints = load_checkpoint_if_available(params=params, model=model)
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )

     model.to(device)
     if is_distributed:
@ -608,15 +676,16 @@ def run(rank, world_size, args):
)
# Note: No learning rate scheduler is used here
for epoch in range(params.start_epoch, params.num_epochs):
for epoch in range(params.start_epoch, params.num_epochs + 1):
if is_distributed:
train_dl.sampler.set_epoch(epoch)
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
@@ -628,6 +697,7 @@ def run(rank, world_size, args):
     save_checkpoint(
         params=params,
         model=model,
+        model_avg=model_avg,
         optimizer=optimizer,
         rank=rank,
     )