From 771583de0f21b10829920a434fdf46eb0ab6800c Mon Sep 17 00:00:00 2001 From: Yifan Yang Date: Thu, 1 Jun 2023 16:37:18 +0800 Subject: [PATCH] add optim --- icefall/rnn_lm/optim.py | 184 ++++++++++++++++++++++++++++++++++++++++ icefall/rnn_lm/train.py | 22 ++++- 2 files changed, 204 insertions(+), 2 deletions(-) create mode 100644 icefall/rnn_lm/optim.py diff --git a/icefall/rnn_lm/optim.py b/icefall/rnn_lm/optim.py new file mode 100644 index 000000000..0ceea7153 --- /dev/null +++ b/icefall/rnn_lm/optim.py @@ -0,0 +1,184 @@ +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.optim import Optimizer + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class NewBobScheduler(LRScheduler): + """ + New-Bob Scheduler + The basic formula is: + lr = lr * annealing_factor if (prev_metric - current_metric) / prev_metric < threshold + where metric is training loss. + + Args: + optimizer: the optimizer to change the learning rates on + annealing_factor: the annealing factor used in new_bob strategy. + threshold: the rate between losses used to perform learning annealing in new_bob strategy. + patient: when the annealing condition is violated patient times, the learning rate is finally reduced. + """ + + def __init__( + self, + optimizer: Optimizer, + annealing_factor: float = 0.5, + threshold: float = 0.0025, + patient: int = 0, + verbose: bool = False, + ): + super(NewBobScheduler, self).__init__(optimizer, verbose) + self.annealing_factor = annealing_factor + self.threshold = threshold + self.patient = patient + self.current_patient = self.patient + self.prev_metric = None + self.current_metric = None + + def step_batch(self, current_metric: Tensor) -> None: + self.current_metric = current_metric + self._set_lrs() + + def get_lr(self): + """Returns the new lr. + + Args: + metric: A number for determining whether to change the lr value. + """ + factor = 1 + if self.prev_metric is not None: + if self.prev_metric == 0: + improvement = 0 + else: + improvement = ( + self.prev_metric - self.current_metric + ) / self.prev_metric + if improvement < self.threshold: + if self.current_patient == 0: + factor = self.annealing_factor + self.current_patient = self.patient + else: + self.current_patient -= 1 + + self.prev_metric = self.current_metric + + return [x * factor for x in self.base_lrs] + + def state_dict(self): + return { + "base_lrs": self.base_lrs, + "prev_metric": self.prev_metric, + "current_metric": current_metric, + "current_patient": self.current_patient, + } diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index c860688c4..af7b29096 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -40,6 +40,7 @@ import torch.optim as optim from dataset import get_dataloader from lhotse.utils import fix_random_seed from model import RnnLmModel +from optim import NewBobScheduler from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -449,6 +450,7 @@ def train_one_epoch( params: AttributeDict, model: nn.Module, optimizer: torch.optim.Optimizer, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, model_avg: nn.Module = None, @@ -471,6 +473,8 @@ def train_one_epoch( The stored model averaged from the start of training. optimizer: The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. valid_dl: @@ -500,6 +504,7 @@ def train_one_epoch( # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + scheduler.step_batch(loss) optimizer.zero_grad() loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) @@ -527,6 +532,7 @@ def train_one_epoch( model_avg=model_avg, params=params, optimizer=optimizer, + scheduler=scheduler, rank=rank, ) @@ -534,11 +540,12 @@ def train_one_epoch( # Note: "frames" here means "num_tokens" this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"]) - + cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] " f"tot_loss[{tot_loss}, ppl: {tot_ppl}], " + f"lr: {cur_lr:.2e}, " f"batch size: {batch_size}" ) @@ -656,10 +663,20 @@ def run(rank, world_size, args): lr=params.lr, weight_decay=params.weight_decay, ) + scheduler = NewBobScheduler(optimizer) + if checkpoints: logging.info("Load optimizer state_dict from checkpoint") optimizer.load_state_dict(checkpoints["optimizer"]) + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + logging.info(f"Loading LM training data from {params.lm_data}") train_dl = get_dataloader( filename=params.lm_data, @@ -674,7 +691,6 @@ def run(rank, world_size, args): params=params, ) - # Note: No learning rate scheduler is used here for epoch in range(params.start_epoch, params.num_epochs + 1): if is_distributed: train_dl.sampler.set_epoch(epoch - 1) @@ -686,6 +702,7 @@ def run(rank, world_size, args): model=model, model_avg=model_avg, optimizer=optimizer, + scheduler=scheduler, train_dl=train_dl, valid_dl=valid_dl, tb_writer=tb_writer, @@ -698,6 +715,7 @@ def run(rank, world_size, args): model=model, model_avg=model_avg, optimizer=optimizer, + scheduler=scheduler, rank=rank, )