mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

Refactor how learning rate is set.

This commit is contained in:
parent 82d58629ea
commit d1e4ae788d
@@ -16,7 +16,7 @@
 import random
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -141,3 +141,152 @@ class Eve(Optimizer):
                 p.addcdiv_(exp_avg, denom, value=-step_size)

         return loss
+
+
+class LRScheduler(object):
+    """
+    Base-class for learning rate schedulers where the learning-rate depends on both the
+    batch and the epoch.
+    """
+    def __init__(self, optimizer: Optimizer, verbose: bool = False):
+        # Attach optimizer
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+        self.verbose = verbose
+
+        for group in optimizer.param_groups:
+            group.setdefault('initial_lr', group['lr'])
+
+        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
+
+        self.epoch = 0
+        self.batch = 0
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the optimizer.
+        """
+        return {'base_lrs': self.base_lrs,
+                'epoch': self.epoch,
+                'batch': self.batch}
+
+    def load_state_dict(self, state_dict):
+        """Loads the scheduler's state.
+
+        Args:
+            state_dict (dict): scheduler state.  Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
+    def get_last_lr(self) -> List[float]:
+        """Return last computed learning rate by current scheduler.  Will be a list of float."""
+        return self._last_lr
+
+    def get_lr(self):
+        # Compute list of learning rates from self.epoch and self.batch and
+        # self.base_lrs; this must be overloaded by the user.
+        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs]
+        raise NotImplementedError
+
+    def step_batch(self, batch: Optional[int] = None) -> None:
+        # Step the batch index, or just set it.  If `batch` is specified, it
+        # must be the batch index from the start of training, i.e. summed over
+        # all epochs.
+        # You can call this in any order; if you don't provide 'batch', it should
+        # of course be called once per batch.
+        if batch is not None:
+            self.batch = batch
+        else:
+            self.batch = self.batch + 1
+        self._set_lrs()
+
+    def step_epoch(self, epoch: Optional[int] = None):
+        # Step the epoch index, or just set it.  If you provide the 'epoch' arg,
+        # you should call this at the start of the epoch; if you don't provide
+        # the 'epoch' arg, you should call it at the end of the epoch.
+        if epoch is not None:
+            self.epoch = epoch
+        else:
+            self.epoch = self.epoch + 1
+        self._set_lrs()
+
+    def _set_lrs(self):
+        values = self.get_lr()
+        assert len(values) == len(self.optimizer.param_groups)
+
+        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+            param_group, lr = data
+            param_group['lr'] = lr
+            self.print_lr(self.verbose, i, lr)
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+    def print_lr(self, is_verbose, group, lr):
+        """Display the current learning rate."""
+        if is_verbose:
+            print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate'
+                  f' of group {group} to {lr:.4e}.')
+
+
+class Eden(LRScheduler):
+    """
+    Eden scheduler.
+    lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
+                       (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))
+
+    E.g. suggest initial-lr = 0.003 (passed to the optimizer).
+
+    Args:
+        optimizer: the optimizer to change the learning rates on
+        lr_batches: the number of batches after which we start significantly
+            decreasing the learning rate, suggest 5000.
+        lr_epochs: the number of epochs after which we start significantly
+            decreasing the learning rate, suggest 6.
+    """
+    def __init__(self, optimizer: Optimizer,
+                 lr_batches: Union[int, float],
+                 lr_epochs: Union[int, float],
+                 verbose: bool = False):
+        super(Eden, self).__init__(optimizer, verbose)
+        self.lr_batches = lr_batches
+        self.lr_epochs = lr_epochs
+
+    def get_lr(self):
+        factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 *
+                  (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25))
+        return [x * factor for x in self.base_lrs]
+
+
+def _test_eden():
+    m = torch.nn.Linear(100, 100)
+    optim = Eve(m.parameters(), lr=0.003)
+
+    scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True)
+
+    for epoch in range(10):
+        scheduler.step_epoch(epoch)  # sets epoch to `epoch`
+
+        for step in range(20):
+            x = torch.randn(200, 100).detach()
+            x.requires_grad = True
+            y = m(x)
+            dy = torch.randn(200, 100).detach()
+            f = (y * dy).sum()
+            f.backward()

+            optim.step()
+            scheduler.step_batch()
+            optim.zero_grad()
+    print("last lr = ", scheduler.get_last_lr())
+    print("state dict = ", scheduler.state_dict())
+
+
+if __name__ == '__main__':
+    _test_eden()
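For intuition about the formula in the Eden docstring above, the short sketch below (an illustration only, not lines from this commit) prints the factor that multiplies the base learning rate at a few (batch, epoch) points, using the suggested lr_batches=5000 and lr_epochs=6: the factor stays close to 1.0 early in training and then decays roughly like batch**-0.5 * epoch**-0.5.

# Illustration only: evaluate the Eden factor
#   ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
#   * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
# at the suggested settings to see the shape of the schedule.
lr_batches, lr_epochs = 5000.0, 6.0

def eden_factor(batch: int, epoch: int) -> float:
    batch_part = ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
    epoch_part = ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
    return batch_part * epoch_part

for epoch in (0, 1, 6, 20):
    for batch in (0, 5000, 50000):
        print(f"epoch={epoch:2d}  batch={batch:6d}  factor={eden_factor(batch, epoch):.3f}")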
@@ -40,7 +40,7 @@ import math
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

 import k2
 import sentencepiece as spm
@@ -55,7 +55,7 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from optim import Eve
+from optim import Eve, Eden
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -74,6 +74,7 @@ from icefall.utils import (
     str2bool,
 )

+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -152,7 +153,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--lr-steps",
+        "--lr-batches",
         type=float,
         default=5000,
         help="""Number of steps that affects how rapidly the learning rate decreases.
@@ -378,7 +379,7 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
 ) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.

@@ -443,7 +444,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
@@ -593,7 +594,7 @@ def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler,
+    scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
@@ -656,17 +657,15 @@
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         loss.backward()
+        scheduler.step_batch(params.batch_idx_train)
         optimizer.step()
         optimizer.zero_grad()
-        scheduler.step()

         if params.print_diagnostics and batch_idx == 5:
             return

-        if (
-            params.batch_idx_train > 0
-            and params.batch_idx_train % params.save_every_n == 0
-        ):
+        if (params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0):
             params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
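The hunk above moves the scheduler update into the batch loop: step_batch(params.batch_idx_train) now sits next to optimizer.step(), and the old scheduler.step() is gone. The sketch below is an illustration of that call order with placeholder model and data; nothing in it comes from the diff except the Eve/Eden imports and the method names.

# Illustration only: placeholder model/optimizer/data; the real script builds
# these from its params and checkpoints.  The call order mirrors the diff:
# step_epoch() once per epoch, then backward -> step_batch -> step -> zero_grad.
import torch
from optim import Eve, Eden  # the module extended in the first part of this diff

model = torch.nn.Linear(80, 500)              # stand-in for the real model
optimizer = Eve(model.parameters(), lr=0.003)
scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)

batch_idx_train = 0
for epoch in range(2):
    scheduler.step_epoch(epoch)               # at the start of each epoch
    for _ in range(10):                       # stand-in for the train_dl loop
        batch_idx_train += 1
        loss = model(torch.randn(16, 80)).sum()   # stand-in loss
        loss.backward()
        scheduler.step_batch(batch_idx_train)     # global batch index, as in the diff
        optimizer.step()
        optimizer.zero_grad()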
@@ -686,13 +685,17 @@
             )

         if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
             )

+            if tb_writer is not None:
+                tb_writer.add_scalar("train/learning_rate", cur_lr)
+
             loss_info.write_summary(
                 tb_writer, "train/current_", params.batch_idx_train
             )
@@ -784,14 +787,7 @@ def run(rank, world_size, args):
         model.parameters(),
         lr=params.initial_lr)

-    # The `epoch` variable in the lambda expression picks up to the value below
-    # in `for epoch in range(params.start_epoch, params.num_epochs):`.  Set it to 0
-    # here to avoid crash in constructor.
-    epoch = 0
-    scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer,
-        lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 *
-                      (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)))
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)


     if checkpoints and "optimizer" in checkpoints:
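The LambdaLR construction (and the `epoch = 0` workaround it needed) is replaced by a single Eden(...) call; the underlying formula is unchanged, only the bookkeeping of batch and epoch moves into the scheduler. Below is a small self-contained check of that equivalence for a fixed epoch, an illustration only and not part of the commit; the 5000/6 values are the suggestions from the Eden docstring.

# Illustration only: Eden's factor matches the removed lambda's factor when
# batch and epoch are held to the same values.
import torch
from optim import Eve, Eden

lr_batches, lr_epochs = 5000.0, 6.0
model = torch.nn.Linear(10, 10)
optimizer = Eve(model.parameters(), lr=0.003)
scheduler = Eden(optimizer, lr_batches, lr_epochs)

epoch = 3
scheduler.step_epoch(epoch)
for batch in (0, 1000, 10000):
    scheduler.step_batch(batch)
    old_style = (((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25 *
                 ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25)
    assert abs(scheduler.get_last_lr()[0] - 0.003 * old_style) < 1e-9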
@@ -854,19 +850,14 @@ def run(rank, world_size, args):
     )

     for epoch in range(params.start_epoch, params.num_epochs):
+        scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)

-        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

-        if rank == 0:
-            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
-
         params.cur_epoch = epoch

         train_one_epoch(
@@ -28,15 +28,18 @@ from lhotse.dataset.sampling.base import CutSampler
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
+
+
+# use duck typing for LRScheduler since we have different possibilities, see
+# our class LRScheduler.
+LRSchedulerType = object


 def save_checkpoint(
     filename: Path,
     model: Union[nn.Module, DDP],
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
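With the `_LRScheduler` import removed, the annotation becomes the duck-typed `LRSchedulerType = object`: the checkpoint helpers are presumed to need only the small interface that both torch's built-in schedulers and the new optim.LRScheduler expose. A minimal sketch of that shared contract, for illustration only (StepLR is used here just as an example of a torch scheduler):

# Illustration only: both scheduler families provide state_dict() and
# load_state_dict(), which is the duck-typed contract relied on here.
import torch
from optim import Eve, Eden  # the classes added earlier in this diff

model = torch.nn.Linear(10, 10)

eve = Eve(model.parameters(), lr=0.003)
eden = Eden(eve, lr_batches=5000, lr_epochs=6)

sgd = torch.optim.SGD(model.parameters(), lr=0.1)
step_lr = torch.optim.lr_scheduler.StepLR(sgd, step_size=10)

for sched in (eden, step_lr):
    state = sched.state_dict()    # serializable scheduler state
    sched.load_state_dict(state)  # and the matching restore path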
@@ -89,7 +92,7 @@ def load_checkpoint(
     filename: Path,
     model: nn.Module,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     strict: bool = False,
@@ -167,7 +170,7 @@ def save_checkpoint_with_global_batch_idx(
     model: Union[nn.Module, DDP],
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,