From d1e4ae788dcddbefd3840c3f5bbc598ec7e225b9 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 10 Apr 2022 15:25:27 +0800
Subject: [PATCH] Refactor how learning rate is set.

---
 .../ASR/pruned_transducer_stateless2/optim.py | 151 +++++++++++++++++-
 .../ASR/pruned_transducer_stateless2/train.py |  43 ++---
 icefall/checkpoint.py                         |  11 +-
 3 files changed, 174 insertions(+), 31 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
index e47c08657..4f7392d3a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
@@ -16,7 +16,7 @@
 
 import random
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -141,3 +141,152 @@ class Eve(Optimizer):
                 p.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
+
+class LRScheduler(object):
+    """
+    Base-class for learning rate schedulers where the learning-rate depends on both the
+    batch and the epoch.
+    """
+    def __init__(self, optimizer: Optimizer, verbose: bool = False):
+        # Attach optimizer
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+        self.verbose = verbose
+
+        for group in optimizer.param_groups:
+            group.setdefault('initial_lr', group['lr'])
+
+        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
+
+        self.epoch = 0
+        self.batch = 0
+
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the optimizer.
+        """
+        return {'base_lrs': self.base_lrs,
+                'epoch': self.epoch,
+                'batch': self.batch}
+
+    def load_state_dict(self, state_dict):
+        """Loads the scheduler's state.
+
+        Args:
+            state_dict (dict): scheduler state.  Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
+    def get_last_lr(self) -> List[float]:
+        """Return the last learning rate computed by the current scheduler.
+        Will be a list of float."""
+        return self._last_lr
+
+    def get_lr(self):
+        # Compute a list of learning rates from self.epoch, self.batch and
+        # self.base_lrs; this must be overloaded by the user.
+        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs]
+        raise NotImplementedError
+
+
+    def step_batch(self, batch: Optional[int] = None) -> None:
+        # Step the batch index, or just set it.  If `batch` is specified, it
+        # must be the batch index from the start of training, i.e. summed over
+        # all epochs.
+        # You can call this in any order; if you don't provide 'batch', it should
+        # of course be called once per batch.
+        if batch is not None:
+            self.batch = batch
+        else:
+            self.batch = self.batch + 1
+        self._set_lrs()
+
+    def step_epoch(self, epoch: Optional[int] = None):
+        # Step the epoch index, or just set it.  If you provide the 'epoch' arg,
+        # you should call this at the start of the epoch; if you don't provide the
+        # 'epoch' arg, you should call it at the end of the epoch.
+        if epoch is not None:
+            self.epoch = epoch
+        else:
+            self.epoch = self.epoch + 1
+        self._set_lrs()
+
+
+    def _set_lrs(self):
+        values = self.get_lr()
+        assert len(values) == len(self.optimizer.param_groups)
+
+        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+            param_group, lr = data
+            param_group['lr'] = lr
+            self.print_lr(self.verbose, i, lr)
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+
+    def print_lr(self, is_verbose, group, lr):
+        """Display the current learning rate.
+        """
+        if is_verbose:
+            print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate'
+                  f' of group {group} to {lr:.4e}.')
+
+
+class Eden(LRScheduler):
+    """
+    Eden scheduler.
+    lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
+                       (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))
+
+    E.g. suggest initial-lr = 0.003 (passed to optimizer).
+
+    Args:
+        optimizer: the optimizer to change the learning rates on
+        lr_batches: the number of batches after which we start significantly
+              decreasing the learning rate, suggest 5000.
+        lr_epochs: the number of epochs after which we start significantly
+              decreasing the learning rate, suggest 6.
+    """
+    def __init__(self, optimizer: Optimizer,
+                 lr_batches: Union[int, float],
+                 lr_epochs: Union[int, float],
+                 verbose: bool = False):
+        super(Eden, self).__init__(optimizer, verbose)
+        self.lr_batches = lr_batches
+        self.lr_epochs = lr_epochs
+
+    def get_lr(self):
+        factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 *
+                  (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25))
+        return [x * factor for x in self.base_lrs]
+
+
+def _test_eden():
+    m = torch.nn.Linear(100, 100)
+    optim = Eve(m.parameters(), lr=0.003)
+
+    scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True)
+
+    for epoch in range(10):
+        scheduler.step_epoch(epoch)  # sets epoch to `epoch`
+
+        for step in range(20):
+            x = torch.randn(200, 100).detach()
+            x.requires_grad = True
+            y = m(x)
+            dy = torch.randn(200, 100).detach()
+            f = (y * dy).sum()
+            f.backward()
+
+            optim.step()
+            scheduler.step_batch()
+            optim.zero_grad()
+
+    print("last lr = ", scheduler.get_last_lr())
+    print("state dict = ", scheduler.state_dict())
+
+
+if __name__ == '__main__':
+    _test_eden()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 73ba17a71..ddd2e8fb7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -40,7 +40,7 @@ import math
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
 import sentencepiece as spm
@@ -55,7 +55,7 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from optim import Eve
+from optim import Eve, Eden
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -74,6 +74,7 @@ from icefall.utils import (
     str2bool,
 )
 
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -152,7 +153,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lr-steps",
+        "--lr-batches",
         type=float,
        default=5000,
        help="""Number of steps that affects how rapidly
        the learning rate decreases.
@@ -378,7 +379,7 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
 ) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.
 
@@ -443,7 +444,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
@@ -593,7 +594,7 @@ def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler,
+    scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
@@ -656,17 +657,15 @@ def train_one_epoch(
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         loss.backward()
+        scheduler.step_batch(params.batch_idx_train)
         optimizer.step()
         optimizer.zero_grad()
-        scheduler.step()
 
         if params.print_diagnostics and batch_idx == 5:
             return
 
-        if (
-            params.batch_idx_train > 0
-            and params.batch_idx_train % params.save_every_n == 0
-        ):
+        if (params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0):
             params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
@@ -686,13 +685,17 @@ def train_one_epoch(
             )
 
         if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
             )
 
             if tb_writer is not None:
+                tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
@@ -784,14 +787,7 @@ def run(rank, world_size, args):
         model.parameters(), lr=params.initial_lr)
 
-    # The `epoch` variable in the lambda expression picks up to the value below
-    # in `for epoch in range(params.start_epoch, params.num_epochs):`.  Set it to 0
-    # here to avoid crash in constructor.
-    epoch = 0
-    scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer,
-        lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 *
-                      (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)))
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
     if checkpoints and "optimizer" in checkpoints:
@@ -854,19 +850,14 @@ def run(rank, world_size, args):
     )
 
     for epoch in range(params.start_epoch, params.num_epochs):
+        scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
 
         cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
-        if rank == 0:
-            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
-
         params.cur_epoch = epoch
 
         train_one_epoch(
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 251456c95..c0d4b3968 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -28,15 +28,18 @@ from lhotse.dataset.sampling.base import CutSampler
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+# use duck typing for LRScheduler since we have different possibilities, see
+# our class LRScheduler.
+LRSchedulerType = object
+
 
 
 def save_checkpoint(
     filename: Path,
     model: Union[nn.Module, DDP],
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
@@ -89,7 +92,7 @@ def load_checkpoint(
     filename: Path,
     model: nn.Module,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     strict: bool = False,
@@ -167,7 +170,7 @@ def save_checkpoint_with_global_batch_idx(
     model: Union[nn.Module, DDP],
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
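
Usage note (not part of the patch): the sketch below illustrates the calling convention the new scheduler expects, namely step_epoch() at the start of each epoch and step_batch() once per batch alongside optimizer.step(), and how its plain-dict state round-trips through state_dict()/load_state_dict(), which is what lets icefall/checkpoint.py get away with duck typing. It assumes the optim.py above is importable as `optim`; the small linear model and random batches are placeholders, not icefall code.

import torch
from optim import Eve, Eden  # the module modified above

model = torch.nn.Linear(80, 10)               # placeholder for the real Transducer model
optimizer = Eve(model.parameters(), lr=0.003)
scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)

batch_idx_train = 0
for epoch in range(2):
    scheduler.step_epoch(epoch)               # call at the start of each epoch
    for _ in range(10):
        batch_idx_train += 1
        loss = model(torch.randn(4, 80)).sum()    # placeholder loss
        loss.backward()
        scheduler.step_batch(batch_idx_train)     # set the LR, then step the optimizer
        optimizer.step()
        optimizer.zero_grad()

# The scheduler state is a plain dict, so it can be saved and restored the same
# way the checkpoint helpers handle it.
state = scheduler.state_dict()
resumed = Eden(optimizer, lr_batches=5000, lr_epochs=6)
resumed.load_state_dict(state)
assert resumed.batch == batch_idx_train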