mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

commit 2036652598 ("update")
parent 584a956bf4

@@ -129,7 +129,7 @@ def get_parser():
     parser.add_argument(
         "--tie-weights",
         type=str2bool,
-        default=False,
+        default=True,
         help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,
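
For readers unfamiliar with weight tying: sharing the embedding matrix with the output projection only works when the output layer's input size equals the embedding dimension. The sketch below is a hypothetical minimal model (TinyRnnLm is not the icefall RnnLmModel; it only illustrates what the --tie-weights option typically does):

import torch
import torch.nn as nn


class TinyRnnLm(nn.Module):
    """Hypothetical illustration of weight tying; not the icefall model."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, tie_weights=True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)
        if tie_weights:
            # Both weights have shape (vocab_size, dim), so tying requires
            # hidden_dim == embedding_dim.
            assert hidden_dim == embedding_dim
            self.output.weight = self.embedding.weight  # one shared tensor

    def forward(self, x):
        # x: (batch, seq_len) token ids -> (batch, seq_len, vocab_size) logits
        h, _ = self.rnn(self.embedding(x))
        return self.output(h)

Tying roughly halves the number of parameters spent on the vocabulary and often acts as a regularizer for small LMs, which is presumably why the default flips to True here.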
@@ -17,20 +17,16 @@
 
 """
 Usage:
-    ./rnn_lm/train.py \
-        --start-epoch 0 \
-        --world-size 2 \
-        --num-epochs 1 \
-        --use-fp16 0 \
-        --tie-weights 0 \
-        --embedding-dim 800 \
-        --hidden-dim 200 \
-        --num-layers 2 \
-        --batch-size 400
-
+    ./rnn_lm/train.py \
+        --exp-dir rnn_lm/exp \
+        --start-epoch 1 \
+        --world-size 1 \
+        --num-epochs 30 \
+        --batch-size 150
 """
 
 import argparse
+import copy
 import logging
 import math
 from pathlib import Path
@@ -50,7 +46,10 @@ from torch.utils.tensorboard import SummaryWriter
 
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import save_checkpoint_with_global_batch_idx
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -92,13 +91,22 @@ def get_parser():
     parser.add_argument(
         "--start-epoch",
         type=int,
-        default=0,
-        help="""Resume training from from this epoch.
-        If it is positive, it will load checkpoint from
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
         exp_dir/epoch-{start_epoch-1}.pt
         """,
     )
 
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
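
Taken together, the two options above allow batch-level resumption in addition to epoch-level resumption. A hypothetical invocation (assuming a checkpoint-20000.pt written by --save-every-n already exists under rnn_lm/exp):

    ./rnn_lm/train.py \
        --exp-dir rnn_lm/exp \
        --start-batch 20000 \
        --world-size 1 \
        --num-epochs 30 \
        --batch-size 150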
@@ -112,14 +120,14 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=True,
+        default=False,
         help="Whether to use half precision training.",
     )
 
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=400,
+        default=150,
     )
 
     parser.add_argument(
@@ -197,7 +205,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=2000,
+        default=20000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -207,6 +215,19 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
     return parser
 
 
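
The update rule quoted in the help string above is an incremental running mean over training batches. A minimal standalone sketch of that formula (an illustration only, not the actual update_averaged_model from icefall.checkpoint, which works on state dicts and handles resumption):

import torch
import torch.nn as nn


@torch.no_grad()
def update_running_average(
    model_cur: nn.Module,
    model_avg: nn.Module,
    batch_idx_train: int,
    average_period: int,
) -> None:
    # model_avg = model * (average_period / batch_idx_train)
    #           + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
    weight_cur = average_period / batch_idx_train
    weight_avg = 1.0 - weight_cur
    for p_cur, p_avg in zip(model_cur.parameters(), model_avg.parameters()):
        p_avg.mul_(weight_avg).add_(p_cur.to(p_avg.dtype), alpha=weight_cur)

Calling this every average_period batches makes model_avg equal the plain average of the parameter snapshots taken at each averaging point since the start of training, which is what the averaged model is meant to approximate.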
@@ -225,9 +246,9 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 100,
+            "log_interval": 200,
             "reset_interval": 2000,
-            "valid_interval": 200,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )
@@ -237,13 +258,16 @@ def get_params() -> AttributeDict:
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
+    model_avg: nn.Module = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 ) -> None:
     """Load checkpoint from file.
 
-    If params.start_epoch is positive, it will load the checkpoint from
-    `params.start_epoch - 1`. Otherwise, this function does nothing.
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
 
     Apart from loading state dict for `model`, `optimizer` and `scheduler`,
     it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
@@ -254,6 +278,8 @@ def load_checkpoint_if_available(
         The return value of :func:`get_params`.
       model:
         The training model.
+      model_avg:
+        The stored model averaged from the start of training.
       optimizer:
         The optimizer that we are using.
       scheduler:
@@ -261,14 +287,20 @@ def load_checkpoint_if_available(
     Returns:
       Return None.
     """
-    if params.start_epoch <= 0:
-        return
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
 
-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     logging.info(f"Loading checkpoint: {filename}")
     saved_params = load_checkpoint(
         filename,
         model=model,
+        model_avg=model_avg,
         optimizer=optimizer,
         scheduler=scheduler,
     )
@@ -283,12 +315,20 @@ def load_checkpoint_if_available(
     for k in keys:
         params[k] = saved_params[k]
 
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
     return saved_params
 
 
 def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
+    model_avg: nn.Module = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
     rank: int = 0,
@@ -300,6 +340,8 @@ def save_checkpoint(
         It is returned by :func:`get_params`.
       model:
         The training model.
+      model_avg:
+        The stored model averaged from the start of training.
     """
     if rank != 0:
         return
@@ -307,6 +349,7 @@ def save_checkpoint(
     save_checkpoint_impl(
         filename=filename,
         model=model,
+        model_avg=model_avg,
         params=params,
         optimizer=optimizer,
         scheduler=scheduler,
@@ -408,6 +451,7 @@ def train_one_epoch(
     optimizer: torch.optim.Optimizer,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    model_avg: nn.Module = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -423,6 +467,8 @@ def train_one_epoch(
         It is returned by :func:`get_params`.
       model:
         The model for training.
+      model_avg:
+        The stored model averaged from the start of training.
       optimizer:
         The optimizer we are using.
       train_dl:
@@ -459,6 +505,18 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
 
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+
         if (
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
@@ -467,6 +525,7 @@
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
                 model=model,
+                model_avg=model_avg,
                 params=params,
                 optimizer=optimizer,
                 rank=rank,
@@ -576,7 +635,16 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    checkpoints = load_checkpoint_if_available(params=params, model=model)
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
 
     model.to(device)
     if is_distributed:
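
A brief aside on the float64 cast above: keeping model_avg in double precision is presumably meant to limit accumulated rounding error, since the running average is updated many times with small weights over a long training run.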
@@ -608,15 +676,16 @@ def run(rank, world_size, args):
     )
 
     # Note: No learning rate scheduler is used here
-    for epoch in range(params.start_epoch, params.num_epochs):
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
         if is_distributed:
-            train_dl.sampler.set_epoch(epoch)
+            train_dl.sampler.set_epoch(epoch - 1)
 
         params.cur_epoch = epoch
 
         train_one_epoch(
             params=params,
             model=model,
            model_avg=model_avg,
             optimizer=optimizer,
             train_dl=train_dl,
             valid_dl=valid_dl,
@@ -628,6 +697,7 @@ def run(rank, world_size, args):
         save_checkpoint(
             params=params,
             model=model,
+            model_avg=model_avg,
             optimizer=optimizer,
             rank=rank,
         )