From 2036652598e8ab7ae2c15a28c29ed323df2c7e50 Mon Sep 17 00:00:00 2001 From: Yifan Yang Date: Wed, 31 May 2023 11:11:11 +0800 Subject: [PATCH] update --- icefall/rnn_lm/compute_perplexity.py | 2 +- icefall/rnn_lm/train.py | 126 +++++++++++++++++++++------ 2 files changed, 99 insertions(+), 29 deletions(-) diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py index a9a8d29d0..80488d5b0 100755 --- a/icefall/rnn_lm/compute_perplexity.py +++ b/icefall/rnn_lm/compute_perplexity.py @@ -129,7 +129,7 @@ def get_parser(): parser.add_argument( "--tie-weights", type=str2bool, - default=False, + default=True, help="""True to share the weights between the input embedding layer and the last output linear layer """, diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 0f0887859..86bf9b7e0 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -17,20 +17,16 @@ """ Usage: - ./rnn_lm/train.py \ - --start-epoch 0 \ - --world-size 2 \ - --num-epochs 1 \ - --use-fp16 0 \ - --tie-weights 0 \ - --embedding-dim 800 \ - --hidden-dim 200 \ - --num-layers 2 \ - --batch-size 400 - +./rnn_lm/train.py \ + --exp-dir rnn_lm/exp \ + --start-epoch 1 \ + --world-size 1 \ + --num-epochs 30 \ + --batch-size 150 """ import argparse +import copy import logging import math from pathlib import Path @@ -50,7 +46,10 @@ from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -92,13 +91,22 @@ def get_parser(): parser.add_argument( "--start-epoch", type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from + default=1, + help="""Resume training from from this epoch. It should be positive. + If larger than 1, it will load checkpoint from exp_dir/epoch-{start_epoch-1}.pt """, ) + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + parser.add_argument( "--exp-dir", type=str, @@ -112,14 +120,14 @@ def get_parser(): parser.add_argument( "--use-fp16", type=str2bool, - default=True, + default=False, help="Whether to use half precision training.", ) parser.add_argument( "--batch-size", type=int, - default=400, + default=150, ) parser.add_argument( @@ -197,7 +205,7 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=2000, + default=20000, help="""Save checkpoint after processing this number of batches" periodically. We save checkpoint to exp-dir/ whenever params.batch_idx_train % save_every_n == 0. The checkpoint filename @@ -207,6 +215,19 @@ def get_parser(): """, ) + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. 
Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + return parser @@ -225,9 +246,9 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 100, + "log_interval": 200, "reset_interval": 2000, - "valid_interval": 200, + "valid_interval": 5000, "env_info": get_env_info(), } ) @@ -237,13 +258,16 @@ def get_params() -> AttributeDict: def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, + model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, ) -> None: """Load checkpoint from file. - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. Apart from loading state dict for `model`, `optimizer` and `scheduler`, it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, @@ -254,6 +278,8 @@ def load_checkpoint_if_available( The return value of :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer that we are using. scheduler: @@ -261,14 +287,20 @@ def load_checkpoint_if_available( Returns: Return None. """ - if params.start_epoch <= 0: - return + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" logging.info(f"Loading checkpoint: {filename}") saved_params = load_checkpoint( filename, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, ) @@ -283,12 +315,20 @@ def load_checkpoint_if_available( for k in keys: params[k] = saved_params[k] + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + return saved_params def save_checkpoint( params: AttributeDict, model: nn.Module, + model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, rank: int = 0, @@ -300,6 +340,8 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. """ if rank != 0: return @@ -307,6 +349,7 @@ def save_checkpoint( save_checkpoint_impl( filename=filename, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -408,6 +451,7 @@ def train_one_epoch( optimizer: torch.optim.Optimizer, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + model_avg: nn.Module = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -423,6 +467,8 @@ def train_one_epoch( It is returned by :func:`get_params`. model: The model for training. 
+ model_avg: + The stored model averaged from the start of training. optimizer: The optimizer we are using. train_dl: @@ -459,6 +505,18 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 @@ -467,6 +525,7 @@ def train_one_epoch( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, rank=rank, @@ -576,7 +635,16 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoints = load_checkpoint_if_available(params=params, model=model) + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) model.to(device) if is_distributed: @@ -608,15 +676,16 @@ def run(rank, world_size, args): ) # Note: No learning rate scheduler is used here - for epoch in range(params.start_epoch, params.num_epochs): + for epoch in range(params.start_epoch, params.num_epochs + 1): if is_distributed: - train_dl.sampler.set_epoch(epoch) + train_dl.sampler.set_epoch(epoch - 1) params.cur_epoch = epoch train_one_epoch( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, train_dl=train_dl, valid_dl=valid_dl, @@ -628,6 +697,7 @@ def run(rank, world_size, args): save_checkpoint( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, rank=rank, )
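
The `--average-period` help above describes a running average taken over the whole training run. As a rough illustration of that formula (the actual update is performed by `icefall.checkpoint.update_averaged_model`, which this patch imports and calls; the helper name below and the assumption that both models are plain, non-DDP `nn.Module`s with an all-floating-point state dict are only for this sketch):

    import torch
    import torch.nn as nn

    def update_model_avg_sketch(
        model_avg: nn.Module,   # averaged model, kept in float64 on rank 0
        model_cur: nn.Module,   # current training model (assumed unwrapped from DDP)
        batch_idx_train: int,   # number of training batches seen so far (> 0)
        average_period: int,    # --average-period
    ) -> None:
        # model_avg = model_cur * (average_period / batch_idx_train)
        #           + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
        weight_cur = average_period / batch_idx_train
        weight_avg = 1.0 - weight_cur
        avg_state = model_avg.state_dict()
        cur_state = model_cur.state_dict()
        with torch.no_grad():
            for name, p_avg in avg_state.items():
                # In-place update of the averaged parameters; the current
                # parameters are cast to float64 before accumulating.
                p_avg.mul_(weight_avg).add_(
                    cur_state[name].to(p_avg.dtype), alpha=weight_cur
                )

In the patch this update runs on rank 0 only, once every `average_period` batches, which is also why `run()` creates `model_avg` as a float64 copy of the model only when `rank == 0`.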
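
Resuming is now controlled by two flags, with `--start-batch` taking priority over `--start-epoch`. A condensed sketch of just the filename selection done by `load_checkpoint_if_available` (the helper name is illustrative):

    from pathlib import Path
    from typing import Optional

    def checkpoint_to_load(
        exp_dir: Path, start_batch: int, start_epoch: int
    ) -> Optional[Path]:
        if start_batch > 0:
            # Resume from a batch-level checkpoint written every --save-every-n batches.
            return exp_dir / f"checkpoint-{start_batch}.pt"
        if start_epoch > 1:
            # Epochs are 1-based now, so epoch N resumes from the end of epoch N-1.
            return exp_dir / f"epoch-{start_epoch - 1}.pt"
        # The defaults (--start-batch 0, --start-epoch 1) mean training from scratch.
        return None

For example, `--exp-dir rnn_lm/exp --start-epoch 11` resolves to `rnn_lm/exp/epoch-10.pt`, while `--start-batch 20000` resolves to `rnn_lm/exp/checkpoint-20000.pt`.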
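
Finally, the `--tie-weights` default in compute_perplexity.py is flipped to True. Tying means the input embedding matrix and the final output linear layer share one weight tensor; a toy sketch of the idea (the class and attribute names below only approximate the real model in icefall/rnn_lm/model.py):

    import torch.nn as nn

    class TinyRnnLm(nn.Module):
        # Toy model: illustrates weight tying only, not the real RnnLmModel.
        def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int,
                     num_layers: int, tie_weights: bool):
            super().__init__()
            self.input_embedding = nn.Embedding(vocab_size, embedding_dim)
            self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
            self.output_linear = nn.Linear(hidden_dim, vocab_size)
            if tie_weights:
                # A single shared [vocab_size, dim] matrix requires
                # embedding_dim == hidden_dim.
                assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim)
                self.output_linear.weight = self.input_embedding.weight

Because the shared matrix can only have one shape, tying is possible only when `--embedding-dim` equals `--hidden-dim`.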