This commit is contained in:
Yifan Yang 2023-05-31 11:11:11 +08:00
parent 584a956bf4
commit 2036652598
2 changed files with 99 additions and 29 deletions

View File

@ -129,7 +129,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--tie-weights", "--tie-weights",
type=str2bool, type=str2bool,
default=False, default=True,
help="""True to share the weights between the input embedding layer and the help="""True to share the weights between the input embedding layer and the
last output linear layer last output linear layer
""", """,

View File

@ -18,19 +18,15 @@
""" """
Usage: Usage:
./rnn_lm/train.py \ ./rnn_lm/train.py \
--start-epoch 0 \ --exp-dir rnn_lm/exp \
--world-size 2 \ --start-epoch 1 \
--num-epochs 1 \ --world-size 1 \
--use-fp16 0 \ --num-epochs 30 \
--tie-weights 0 \ --batch-size 150
--embedding-dim 800 \
--hidden-dim 200 \
--num-layers 2 \
--batch-size 400
""" """
import argparse import argparse
import copy
import logging import logging
import math import math
from pathlib import Path from pathlib import Path
@ -50,7 +46,10 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -92,13 +91,22 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--start-epoch", "--start-epoch",
type=int, type=int,
default=0, default=1,
help="""Resume training from from this epoch. help="""Resume training from from this epoch. It should be positive.
If it is positive, it will load checkpoint from If larger than 1, it will load checkpoint from
exp_dir/epoch-{start_epoch-1}.pt exp_dir/epoch-{start_epoch-1}.pt
""", """,
) )
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
@ -112,14 +120,14 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--use-fp16", "--use-fp16",
type=str2bool, type=str2bool,
default=True, default=False,
help="Whether to use half precision training.", help="Whether to use half precision training.",
) )
parser.add_argument( parser.add_argument(
"--batch-size", "--batch-size",
type=int, type=int,
default=400, default=150,
) )
parser.add_argument( parser.add_argument(
@ -197,7 +205,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--save-every-n", "--save-every-n",
type=int, type=int,
default=2000, default=20000,
help="""Save checkpoint after processing this number of batches" help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename params.batch_idx_train % save_every_n == 0. The checkpoint filename
@ -207,6 +215,19 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--average-period",
type=int,
default=200,
help="""Update the averaged model, namely `model_avg`, after processing
this number of batches. `model_avg` is a separate version of model,
in which each floating-point parameter is the average of all the
parameters from the start of training. Each time we take the average,
we do: `model_avg = model * (average_period / batch_idx_train) +
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
""",
)
return parser return parser
@ -225,9 +246,9 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1, "best_train_epoch": -1,
"best_valid_epoch": -1, "best_valid_epoch": -1,
"batch_idx_train": 0, "batch_idx_train": 0,
"log_interval": 100, "log_interval": 200,
"reset_interval": 2000, "reset_interval": 2000,
"valid_interval": 200, "valid_interval": 5000,
"env_info": get_env_info(), "env_info": get_env_info(),
} }
) )
@ -237,13 +258,16 @@ def get_params() -> AttributeDict:
def load_checkpoint_if_available( def load_checkpoint_if_available(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
model_avg: nn.Module = None,
optimizer: Optional[torch.optim.Optimizer] = None, optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None: ) -> None:
"""Load checkpoint from file. """Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from If params.start_batch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing. `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model`, `optimizer` and `scheduler`, Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
@ -254,6 +278,8 @@ def load_checkpoint_if_available(
The return value of :func:`get_params`. The return value of :func:`get_params`.
model: model:
The training model. The training model.
model_avg:
The stored model averaged from the start of training.
optimizer: optimizer:
The optimizer that we are using. The optimizer that we are using.
scheduler: scheduler:
@ -261,14 +287,20 @@ def load_checkpoint_if_available(
Returns: Returns:
Return None. Return None.
""" """
if params.start_epoch <= 0: if params.start_batch > 0:
return filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
logging.info(f"Loading checkpoint: {filename}") logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint( saved_params = load_checkpoint(
filename, filename,
model=model, model=model,
model_avg=model_avg,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler, scheduler=scheduler,
) )
@ -283,12 +315,20 @@ def load_checkpoint_if_available(
for k in keys: for k in keys:
params[k] = saved_params[k] params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
def save_checkpoint( def save_checkpoint(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
model_avg: nn.Module = None,
optimizer: Optional[torch.optim.Optimizer] = None, optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0, rank: int = 0,
@ -300,6 +340,8 @@ def save_checkpoint(
It is returned by :func:`get_params`. It is returned by :func:`get_params`.
model: model:
The training model. The training model.
model_avg:
The stored model averaged from the start of training.
""" """
if rank != 0: if rank != 0:
return return
@ -307,6 +349,7 @@ def save_checkpoint(
save_checkpoint_impl( save_checkpoint_impl(
filename=filename, filename=filename,
model=model, model=model,
model_avg=model_avg,
params=params, params=params,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler, scheduler=scheduler,
@ -408,6 +451,7 @@ def train_one_epoch(
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
model_avg: nn.Module = None,
tb_writer: Optional[SummaryWriter] = None, tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1, world_size: int = 1,
rank: int = 0, rank: int = 0,
@ -423,6 +467,8 @@ def train_one_epoch(
It is returned by :func:`get_params`. It is returned by :func:`get_params`.
model: model:
The model for training. The model for training.
model_avg:
The stored model averaged from the start of training.
optimizer: optimizer:
The optimizer we are using. The optimizer we are using.
train_dl: train_dl:
@ -459,6 +505,18 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0) clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if ( if (
params.batch_idx_train > 0 params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0 and params.batch_idx_train % params.save_every_n == 0
@ -467,6 +525,7 @@ def train_one_epoch(
out_dir=params.exp_dir, out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train, global_batch_idx=params.batch_idx_train,
model=model, model=model,
model_avg=model_avg,
params=params, params=params,
optimizer=optimizer, optimizer=optimizer,
rank=rank, rank=rank,
@ -576,7 +635,16 @@ def run(rank, world_size, args):
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model) assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
model.to(device) model.to(device)
if is_distributed: if is_distributed:
@ -608,15 +676,16 @@ def run(rank, world_size, args):
) )
# Note: No learning rate scheduler is used here # Note: No learning rate scheduler is used here
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs + 1):
if is_distributed: if is_distributed:
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch params.cur_epoch = epoch
train_one_epoch( train_one_epoch(
params=params, params=params,
model=model, model=model,
model_avg=model_avg,
optimizer=optimizer, optimizer=optimizer,
train_dl=train_dl, train_dl=train_dl,
valid_dl=valid_dl, valid_dl=valid_dl,
@ -628,6 +697,7 @@ def run(rank, world_size, args):
save_checkpoint( save_checkpoint(
params=params, params=params,
model=model, model=model,
model_avg=model_avg,
optimizer=optimizer, optimizer=optimizer,
rank=rank, rank=rank,
) )