Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit 2036652598 (parent 584a956bf4): update

@@ -129,7 +129,7 @@ def get_parser():
     parser.add_argument(
         "--tie-weights",
         type=str2bool,
-        default=False,
+        default=True,
         help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,
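
For context on the flipped default: tying weights simply reuses the input embedding matrix as the weight of the final projection, which requires the embedding dimension to match the projection's input dimension. A minimal PyTorch sketch of the idea (class and attribute names here are illustrative, not those of icefall's RnnLmModel):

# Minimal sketch of weight tying; names are illustrative, not taken from
# icefall's RnnLmModel.
import torch
import torch.nn as nn


class TinyRnnLm(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, tie_weights: bool = True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, embedding_dim, batch_first=True)
        self.output = nn.Linear(embedding_dim, vocab_size)
        if tie_weights:
            # Both tensors have shape (vocab_size, embedding_dim), so the
            # output projection can share the embedding matrix directly.
            self.output.weight = self.embedding.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        out, _ = self.rnn(self.embedding(tokens))
        return self.output(out)
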
@@ -17,20 +17,16 @@

 """
 Usage:
     ./rnn_lm/train.py \
-        --start-epoch 0 \
-        --world-size 2 \
-        --num-epochs 1 \
-        --use-fp16 0 \
-        --tie-weights 0 \
-        --embedding-dim 800 \
-        --hidden-dim 200 \
-        --num-layers 2 \
-        --batch-size 400
+        --exp-dir rnn_lm/exp \
+        --start-epoch 1 \
+        --world-size 1 \
+        --num-epochs 30 \
+        --batch-size 150

 """

 import argparse
+import copy
 import logging
 import math
 from pathlib import Path
@@ -50,7 +46,10 @@ from torch.utils.tensorboard import SummaryWriter

 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import save_checkpoint_with_global_batch_idx
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -92,13 +91,22 @@ def get_parser():
     parser.add_argument(
         "--start-epoch",
         type=int,
-        default=0,
-        help="""Resume training from from this epoch.
-        If it is positive, it will load checkpoint from
+        default=1,
+        help="""Resume training from from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
         exp_dir/epoch-{start_epoch-1}.pt
         """,
     )

+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -112,14 +120,14 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=True,
+        default=False,
         help="Whether to use half precision training.",
     )

     parser.add_argument(
         "--batch-size",
         type=int,
-        default=400,
+        default=150,
     )

     parser.add_argument(
@@ -197,7 +205,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=2000,
+        default=20000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -207,6 +215,19 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
     return parser


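
The recurrence quoted in the help string is an incremental running average of the parameters over training batches. A small per-tensor sketch of that update (illustrative only; in this commit the actual work is delegated to `update_averaged_model` from icefall.checkpoint, invoked in a later hunk):

# Illustrative per-tensor form of the recurrence in the help string above;
# the commit itself calls icefall.checkpoint.update_averaged_model instead.
import torch


@torch.no_grad()
def averaged_update(
    avg_param: torch.Tensor,
    cur_param: torch.Tensor,
    batch_idx_train: int,
    average_period: int,
) -> None:
    # model_avg = model * (p / t) + model_avg * ((t - p) / t), done in place.
    t = float(batch_idx_train)
    p = float(average_period)
    avg_param.mul_((t - p) / t).add_(cur_param, alpha=p / t)
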
@@ -225,9 +246,9 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 100,
+            "log_interval": 200,
             "reset_interval": 2000,
-            "valid_interval": 200,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )
@@ -237,13 +258,16 @@ def get_params() -> AttributeDict:
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
+    model_avg: nn.Module = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 ) -> None:
     """Load checkpoint from file.

-    If params.start_epoch is positive, it will load the checkpoint from
-    `params.start_epoch - 1`. Otherwise, this function does nothing.
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.

     Apart from loading state dict for `model`, `optimizer` and `scheduler`,
     it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
@@ -254,6 +278,8 @@ def load_checkpoint_if_available(
         The return value of :func:`get_params`.
       model:
         The training model.
+      model_avg:
+        The stored model averaged from the start of training.
       optimizer:
         The optimizer that we are using.
       scheduler:
@@ -261,14 +287,20 @@ def load_checkpoint_if_available(
     Returns:
       Return None.
     """
-    if params.start_epoch <= 0:
-        return
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"

-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     logging.info(f"Loading checkpoint: {filename}")
     saved_params = load_checkpoint(
         filename,
         model=model,
+        model_avg=model_avg,
         optimizer=optimizer,
         scheduler=scheduler,
     )
@@ -283,12 +315,20 @@ def load_checkpoint_if_available(
     for k in keys:
         params[k] = saved_params[k]

+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
     return saved_params


 def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
+    model_avg: nn.Module = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
     rank: int = 0,
@@ -300,6 +340,8 @@ def save_checkpoint(
         It is returned by :func:`get_params`.
       model:
         The training model.
+      model_avg:
+        The stored model averaged from the start of training.
     """
     if rank != 0:
         return
@@ -307,6 +349,7 @@ def save_checkpoint(
     save_checkpoint_impl(
         filename=filename,
         model=model,
+        model_avg=model_avg,
         params=params,
         optimizer=optimizer,
         scheduler=scheduler,
@@ -408,6 +451,7 @@ def train_one_epoch(
     optimizer: torch.optim.Optimizer,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    model_avg: nn.Module = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -423,6 +467,8 @@ def train_one_epoch(
         It is returned by :func:`get_params`.
       model:
         The model for training.
+      model_avg:
+        The stored model averaged from the start of training.
       optimizer:
         The optimizer we are using.
       train_dl:
@@ -459,6 +505,18 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
         if (
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
@@ -467,6 +525,7 @@ def train_one_epoch(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
                 model=model,
+                model_avg=model_avg,
                 params=params,
                 optimizer=optimizer,
                 rank=rank,
@@ -576,7 +635,16 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    checkpoints = load_checkpoint_if_available(params=params, model=model)
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )

     model.to(device)
     if is_distributed:
@@ -608,15 +676,16 @@ def run(rank, world_size, args):
     )

     # Note: No learning rate scheduler is used here
-    for epoch in range(params.start_epoch, params.num_epochs):
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
         if is_distributed:
-            train_dl.sampler.set_epoch(epoch)
+            train_dl.sampler.set_epoch(epoch - 1)

         params.cur_epoch = epoch

         train_one_epoch(
             params=params,
             model=model,
+            model_avg=model_avg,
             optimizer=optimizer,
             train_dl=train_dl,
             valid_dl=valid_dl,
@@ -628,6 +697,7 @@ def run(rank, world_size, args):
     save_checkpoint(
         params=params,
         model=model,
+        model_avg=model_avg,
         optimizer=optimizer,
         rank=rank,
     )