k2-fsa/icefall (mirror of https://github.com/k2-fsa/icefall.git)

Commit: d1e4ae788d
Parent: 82d58629ea

Refactor how learning rate is set.
@@ -16,7 +16,7 @@


 import random
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -141,3 +141,152 @@ class Eve(Optimizer):
                 p.addcdiv_(exp_avg, denom, value=-step_size)

         return loss
+
+
+class LRScheduler(object):
+    """
+    Base-class for learning rate schedulers where the learning-rate depends on both the
+    batch and the epoch.
+    """
+    def __init__(self, optimizer: Optimizer, verbose: bool = False):
+        # Attach optimizer
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+        self.verbose = verbose
+
+        for group in optimizer.param_groups:
+            group.setdefault('initial_lr', group['lr'])
+
+        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
+
+        self.epoch = 0
+        self.batch = 0
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the optimizer.
+        """
+        return {'base_lrs': self.base_lrs,
+                'epoch': self.epoch,
+                'batch': self.batch}
+
+    def load_state_dict(self, state_dict):
+        """Loads the scheduler's state.
+
+        Args:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
+    def get_last_lr(self) -> List[float]:
+        """Return last computed learning rate by current scheduler.  Will be a list of float.
+        """
+        return self._last_lr
+
+    def get_lr(self):
+        # Compute list of learning rates from self.epoch and self.batch and
+        # self.base_lrs; this must be overloaded by the user.
+        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs]
+        raise NotImplementedError
+
+    def step_batch(self, batch: Optional[int] = None) -> None:
+        # Step the batch index, or just set it.  If `batch` is specified, it
+        # must be the batch index from the start of training, i.e. summed over
+        # all epochs.
+        # If you don't provide 'batch', this should of course be called once
+        # per batch.
+        if batch is not None:
+            self.batch = batch
+        else:
+            self.batch = self.batch + 1
+        self._set_lrs()
+
+    def step_epoch(self, epoch: Optional[int] = None):
+        # Step the epoch index, or just set it.  If you provide the 'epoch' arg,
+        # you should call this at the start of the epoch; if you don't provide
+        # the 'epoch' arg, you should call it at the end of the epoch.
+        if epoch is not None:
+            self.epoch = epoch
+        else:
+            self.epoch = self.epoch + 1
+        self._set_lrs()
+
+    def _set_lrs(self):
+        values = self.get_lr()
+        assert len(values) == len(self.optimizer.param_groups)
+
+        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+            param_group, lr = data
+            param_group['lr'] = lr
+            self.print_lr(self.verbose, i, lr)
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+    def print_lr(self, is_verbose, group, lr):
+        """Display the current learning rate.
+        """
+        if is_verbose:
+            print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate'
+                  f' of group {group} to {lr:.4e}.')
+
+
+class Eden(LRScheduler):
+    """
+    Eden scheduler.
+
+      lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
+                         (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25))
+
+    E.g. suggest initial-lr = 0.003 (passed to optimizer).
+
+    Args:
+        optimizer: the optimizer to change the learning rates on
+        lr_batches: the number of batches after which we start significantly
+            decreasing the learning rate, suggest 5000.
+        lr_epochs: the number of epochs after which we start significantly
+            decreasing the learning rate, suggest 6.
+    """
+    def __init__(self, optimizer: Optimizer,
+                 lr_batches: Union[int, float],
+                 lr_epochs: Union[int, float],
+                 verbose: bool = False):
+        super(Eden, self).__init__(optimizer, verbose)
+        self.lr_batches = lr_batches
+        self.lr_epochs = lr_epochs
+
+    def get_lr(self):
+        factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 *
+                  (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25))
+        return [x * factor for x in self.base_lrs]
+
+
+def _test_eden():
+    m = torch.nn.Linear(100, 100)
+    optim = Eve(m.parameters(), lr=0.003)
+
+    scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True)
+
+    for epoch in range(10):
+        scheduler.step_epoch(epoch)  # sets epoch to `epoch`
+
+        for step in range(20):
+            x = torch.randn(200, 100).detach()
+            x.requires_grad = True
+            y = m(x)
+            dy = torch.randn(200, 100).detach()
+            f = (y * dy).sum()
+            f.backward()
+
+            optim.step()
+            scheduler.step_batch()
+            optim.zero_grad()
+
+    print("last lr = ", scheduler.get_last_lr())
+    print("state dict = ", scheduler.state_dict())
+
+
+if __name__ == '__main__':
+    _test_eden()
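[Editor's note] The Eden formula in the docstring above is just the product of two independent decay terms, one driven by the global batch count and one by the epoch count. The standalone sketch below is an illustration added for this write-up, not part of the commit; the helper name eden_factor and the sample points are invented. It prints the resulting learning rate for a few points, assuming the suggested initial lr of 0.003, lr_batches=5000 and lr_epochs=6.

def eden_factor(batch: float, epoch: float,
                lr_batches: float = 5000.0, lr_epochs: float = 6.0) -> float:
    # Mirrors the expression in Eden.get_lr(): each term stays close to 1.0
    # while batch << lr_batches (or epoch << lr_epochs) and then decays
    # roughly like batch ** -0.5 (or epoch ** -0.5).
    return (((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25 *
            ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25)

if __name__ == "__main__":
    initial_lr = 0.003  # value suggested in the docstring above
    for batch, epoch in [(0, 0), (5000, 1), (20000, 3), (100000, 10)]:
        print(f"batch={batch:6d}  epoch={epoch:2d}  "
              f"lr={initial_lr * eden_factor(batch, epoch):.2e}")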
@@ -40,7 +40,7 @@ import math
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

 import k2
 import sentencepiece as spm
@@ -55,7 +55,7 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from optim import Eve
+from optim import Eve, Eden
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -74,6 +74,7 @@ from icefall.utils import (
     str2bool,
 )

+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -152,7 +153,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--lr-steps",
+        "--lr-batches",
         type=float,
         default=5000,
         help="""Number of steps that affects how rapidly the learning rate decreases.
@@ -378,7 +379,7 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
 ) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.

@@ -443,7 +444,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
@@ -593,7 +594,7 @@ def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler,
+    scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
@@ -656,17 +657,15 @@ def train_one_epoch(
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         loss.backward()
+        scheduler.step_batch(params.batch_idx_train)
         optimizer.step()
         optimizer.zero_grad()
-        scheduler.step()

         if params.print_diagnostics and batch_idx == 5:
             return

-        if (
-            params.batch_idx_train > 0
-            and params.batch_idx_train % params.save_every_n == 0
-        ):
+        if (params.batch_idx_train > 0
+                and params.batch_idx_train % params.save_every_n == 0):
             params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
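[Editor's note] The hunk above removes the unconditional per-batch scheduler.step() and instead drives the scheduler from the global batch counter, just before optimizer.step(); a later hunk also calls scheduler.step_epoch(epoch) at the start of each epoch. The sketch below is an illustration only: it assumes a plain torch.optim.SGD optimizer and random data in place of the recipe's Eve optimizer and dataloader, and that it runs next to the optim.py file extended above.

import torch

from optim import Eden  # the scheduler class added in this commit

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.003)
scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)

batch_idx_train = 0  # global batch counter, summed over all epochs
for epoch in range(2):
    scheduler.step_epoch(epoch)  # set the epoch index at the start of the epoch
    for _ in range(3):  # stand-in for the real dataloader loop
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        batch_idx_train += 1
        scheduler.step_batch(batch_idx_train)  # lr follows the global batch index
        optimizer.step()
        optimizer.zero_grad()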
@@ -686,13 +685,17 @@
             )

         if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
             )

             if tb_writer is not None:
+                tb_writer.add_scalar("train/learning_rate", cur_lr)
+
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
@@ -784,14 +787,7 @@ def run(rank, world_size, args):
         model.parameters(),
         lr=params.initial_lr)

-    # The `epoch` variable in the lambda expression picks up to the value below
-    # in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0
-    # here to avoid crash in constructor.
-    epoch = 0
-    scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer,
-        lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 *
-                      (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)))
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)


     if checkpoints and "optimizer" in checkpoints:
@@ -854,19 +850,14 @@ def run(rank, world_size, args):
     )

     for epoch in range(params.start_epoch, params.num_epochs):
+        scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)

         cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

-        if rank == 0:
-            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
-
         params.cur_epoch = epoch

         train_one_epoch(
@@ -28,15 +28,18 @@ from lhotse.dataset.sampling.base import CutSampler
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler


+# use duck typing for LRScheduler since we have different possibilities, see
+# our class LRScheduler.
+LRSchedulerType = object
+
 def save_checkpoint(
     filename: Path,
     model: Union[nn.Module, DDP],
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,
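[Editor's note] Unlike the training script, this library-level checkpoint code cannot import the recipe's new LRScheduler class, which is presumably why the annotation falls back to a plain object here. Duck typing is enough because the checkpoint helpers only need the scheduler to expose state_dict() and load_state_dict(), which both torch.optim.lr_scheduler._LRScheduler and the new optim.LRScheduler provide. A tiny sketch of that pattern (the helper name pack_scheduler_state is invented for illustration):

from typing import Any, Dict, Optional

LRSchedulerType = object  # duck typing, as in the hunk above

def pack_scheduler_state(scheduler: Optional[LRSchedulerType]) -> Dict[str, Any]:
    # Only state_dict() is needed, so any scheduler-like object is acceptable.
    return {"scheduler": scheduler.state_dict() if scheduler is not None else None}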
@@ -89,7 +92,7 @@ def load_checkpoint(
     filename: Path,
     model: nn.Module,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     strict: bool = False,
@@ -167,7 +170,7 @@ def save_checkpoint_with_global_batch_idx(
     model: Union[nn.Module, DDP],
     params: Optional[Dict[str, Any]] = None,
     optimizer: Optional[Optimizer] = None,
-    scheduler: Optional[_LRScheduler] = None,
+    scheduler: Optional[LRSchedulerType] = None,
     scaler: Optional[GradScaler] = None,
     sampler: Optional[CutSampler] = None,
     rank: int = 0,