Refactor how learning rate is set.

Daniel Povey 2022-04-10 15:25:27 +08:00
parent 82d58629ea
commit d1e4ae788d
3 changed files with 174 additions and 31 deletions


@@ -16,7 +16,7 @@
 import random
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -141,3 +141,152 @@ class Eve(Optimizer):
                 p.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
+
+
+class LRScheduler(object):
+    """
+    Base-class for learning rate schedulers where the learning-rate depends on both the
+    batch and the epoch.
+    """
+
+    def __init__(self, optimizer: Optimizer, verbose: bool = False):
+        # Attach optimizer
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+        self.verbose = verbose
+
+        for group in optimizer.param_groups:
+            group.setdefault('initial_lr', group['lr'])
+
+        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
+
+        self.epoch = 0
+        self.batch = 0
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the optimizer.
+        """
+        return {'base_lrs': self.base_lrs,
+                'epoch': self.epoch,
+                'batch': self.batch}
+
+    def load_state_dict(self, state_dict):
+        """Loads the scheduler's state.
+
+        Args:
+            state_dict (dict): scheduler state.  Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
+    def get_last_lr(self) -> List[float]:
+        """Return the last learning rate computed by the current scheduler,
+        as a list of floats (one per parameter group)."""
+        return self._last_lr
+
+    def get_lr(self):
+        # Compute a list of learning rates from self.epoch, self.batch and
+        # self.base_lrs; this must be overloaded by the user, e.g.
+        # return [some_formula(self.batch, self.epoch, base_lr)
+        #         for base_lr in self.base_lrs]
+        raise NotImplementedError
+
+    def step_batch(self, batch: Optional[int] = None) -> None:
+        # Step the batch index, or just set it.  If `batch` is specified, it
+        # must be the batch index from the start of training, i.e. summed over
+        # all epochs.
+        # You can call this in any order; if you don't provide `batch`, it
+        # should of course be called once per batch.
+        if batch is not None:
+            self.batch = batch
+        else:
+            self.batch = self.batch + 1
+        self._set_lrs()
+
+    def step_epoch(self, epoch: Optional[int] = None):
+        # Step the epoch index, or just set it.  If you provide the `epoch`
+        # arg, you should call this at the start of the epoch; if you don't
+        # provide the `epoch` arg, you should call it at the end of the epoch.
+        if epoch is not None:
+            self.epoch = epoch
+        else:
+            self.epoch = self.epoch + 1
+        self._set_lrs()
+
+    def _set_lrs(self):
+        values = self.get_lr()
+        assert len(values) == len(self.optimizer.param_groups)
+
+        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+            param_group, lr = data
+            param_group['lr'] = lr
+            self.print_lr(self.verbose, i, lr)
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+    def print_lr(self, is_verbose, group, lr):
+        """Display the current learning rate."""
+        if is_verbose:
+            print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate'
+                  f' of group {group} to {lr:.4e}.')
+
+
+class Eden(LRScheduler):
+    """
+    Eden scheduler.
+
+      lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
+                         (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))
+
+    E.g. suggest initial-lr = 0.003 (passed to the optimizer).
+
+    Args:
+        optimizer: the optimizer to change the learning rates on
+        lr_batches: the number of batches after which we start significantly
+            decreasing the learning rate, suggest 5000.
+        lr_epochs: the number of epochs after which we start significantly
+            decreasing the learning rate, suggest 6.
+    """
+
+    def __init__(self, optimizer: Optimizer,
+                 lr_batches: Union[int, float],
+                 lr_epochs: Union[int, float],
+                 verbose: bool = False):
+        super(Eden, self).__init__(optimizer, verbose)
+        self.lr_batches = lr_batches
+        self.lr_epochs = lr_epochs
+
+    def get_lr(self):
+        factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 *
+                  (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25))
+        return [x * factor for x in self.base_lrs]
+
+
+def _test_eden():
+    m = torch.nn.Linear(100, 100)
+    optim = Eve(m.parameters(), lr=0.003)
+
+    scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True)
+
+    for epoch in range(10):
+        scheduler.step_epoch(epoch)  # sets epoch to `epoch`
+
+        for step in range(20):
+            x = torch.randn(200, 100).detach()
+            x.requires_grad = True
+            y = m(x)
+            dy = torch.randn(200, 100).detach()
+            f = (y * dy).sum()
+            f.backward()
+
+            optim.step()
+            scheduler.step_batch()
+            optim.zero_grad()
+
+    print("last lr = ", scheduler.get_last_lr())
+    print("state dict = ", scheduler.state_dict())
+
+
+if __name__ == '__main__':
+    _test_eden()
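
To get a feel for the Eden schedule added above, here is a standalone sketch (not part of the diff; eden_factor is a hypothetical helper that simply re-evaluates the formula from the Eden docstring) that prints the scaling factor at a few (batch, epoch) points, using the suggested lr_batches=5000 and lr_epochs=6:

def eden_factor(batch: int, epoch: int,
                lr_batches: float = 5000.0, lr_epochs: float = 6.0) -> float:
    # Same formula as Eden.get_lr() above, with initial_lr factored out.
    return (((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25 *
            ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25)


for batch, epoch in [(0, 0), (5000, 0), (5000, 6), (50000, 6), (200000, 20)]:
    print(f"batch={batch:6d} epoch={epoch:2d} factor={eden_factor(batch, epoch):.3f}")

With these settings the factor is 1.0 at the start of training and is still 2 ** -0.25 (about 0.84) at batch=5000, epoch=0, so lr_batches and lr_epochs mark where the decay becomes noticeable rather than hard steps.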


@@ -40,7 +40,7 @@ import math
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
 import sentencepiece as spm
@@ -55,7 +55,7 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from optim import Eve
+from optim import Eve, Eden
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -74,6 +74,7 @@ from icefall.utils import (
     str2bool,
 )
 
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -152,7 +153,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lr-steps",
+        "--lr-batches",
         type=float,
         default=5000,
         help="""Number of steps that affects how rapidly the learning rate decreases.
@@ -378,7 +379,7 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
 ) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.
@@ -443,7 +444,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
@@ -593,7 +594,7 @@ def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler,
+    scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
@@ -656,17 +657,15 @@ def train_one_epoch(
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         loss.backward()
+        scheduler.step_batch(params.batch_idx_train)
         optimizer.step()
         optimizer.zero_grad()
-        scheduler.step()
 
         if params.print_diagnostics and batch_idx == 5:
             return
 
-        if (
-            params.batch_idx_train > 0
-            and params.batch_idx_train % params.save_every_n == 0
-        ):
+        if (params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0):
             params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
@@ -686,13 +685,17 @@
             )
 
         if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
             )
 
             if tb_writer is not None:
+                tb_writer.add_scalar("train/learning_rate", cur_lr)
+
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
@@ -784,14 +787,7 @@ def run(rank, world_size, args):
         model.parameters(),
         lr=params.initial_lr)
 
-    # The `epoch` variable in the lambda expression picks up to the value below
-    # in `for epoch in range(params.start_epoch, params.num_epochs):`.  Set it to 0
-    # here to avoid crash in constructor.
-    epoch = 0
-    scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer,
-        lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 *
-                      (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)))
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
     if checkpoints and "optimizer" in checkpoints:
@@ -854,19 +850,14 @@ def run(rank, world_size, args):
     )
 
     for epoch in range(params.start_epoch, params.num_epochs):
+        scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
 
         cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
-        if rank == 0:
-            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
-
         params.cur_epoch = epoch
 
         train_one_epoch(
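
The training-loop changes above boil down to: call scheduler.step_epoch(epoch) once at the start of each epoch, and scheduler.step_batch(params.batch_idx_train) once per batch, before optimizer.step(). A condensed sketch of that call order (illustration only: a toy linear model and torch.optim.SGD stand in for the real model and the Eve optimizer, and it assumes the optim.py module above is importable):

import torch
from optim import Eden  # the scheduler added in this commit (assumed importable)

model = torch.nn.Linear(80, 80)
optimizer = torch.optim.SGD(model.parameters(), lr=0.003)
scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)

batch_idx_train = 0  # global batch counter, like params.batch_idx_train in train.py
for epoch in range(2):
    scheduler.step_epoch(epoch)              # set the epoch index at the start of the epoch
    for _ in range(10):                      # stand-in for the real dataloader
        batch_idx_train += 1
        loss = model(torch.randn(16, 80)).pow(2).mean()
        loss.backward()
        scheduler.step_batch(batch_idx_train)  # set the global batch index before stepping
        optimizer.step()
        optimizer.zero_grad()
    print("epoch", epoch, "lr", scheduler.get_last_lr()[0])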


@@ -28,15 +28,18 @@ from lhotse.dataset.sampling.base import CutSampler
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+
+# use duck typing for LRScheduler since we have different possibilities, see
+# our class LRScheduler.
+LRSchedulerType = object
 
 
 def save_checkpoint(
     filename: Path,
     model: Union[nn.Module, DDP],
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
@@ -89,7 +92,7 @@ def load_checkpoint(
     filename: Path,
     model: nn.Module,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     strict: bool = False,
@@ -167,7 +170,7 @@ def save_checkpoint_with_global_batch_idx(
     model: Union[nn.Module, DDP],
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
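
The duck-typed LRSchedulerType = object introduced above is enough because the checkpoint helpers only ever call state_dict() and load_state_dict() on the scheduler, and both torch's built-in schedulers and the new LRScheduler class provide those methods. A minimal sketch of that idea (illustration only; scheduler_state is a hypothetical helper, and a torch StepLR stands in for either scheduler type):

import torch


def scheduler_state(scheduler: object) -> dict:
    # Any object exposing state_dict() works here, regardless of its base class;
    # that is all the checkpoint code relies on.
    return scheduler.state_dict()


model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

state = scheduler_state(scheduler)  # the same call would work for Eden
scheduler.load_state_dict(state)    # restoring is symmetric
print(sorted(state.keys()))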