Move to Gloam optimizer, exponential lrate

Daniel Povey 2021-09-08 13:59:50 +08:00
parent d313c27c14
commit 56a88badd1
2 changed files with 158 additions and 9 deletions


@@ -813,11 +813,6 @@ class Moam(object):
 class Foam(object):
     """
-    Implements Foam optimizer. This is a modified version of the Noam optimizer
-    which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
-    but changed to use Madam (see above) instead of Adam as the base optimizer, and then
-    to change the learning rate schedule and how it is specified.
     This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
@@ -946,6 +941,153 @@ class Foam(object):
+class Gloam(object):
+    """
+    Implements Gloam optimizer. This is a modified version of the Noam optimizer
+    which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
+    but changed to use Madam (see above) instead of Adam as the base optimizer, and then
+    to change the learning rate schedule and how it is specified. There is
+    a warm-up stage, but after the rate reaches `max_lrate` it stays constant for the
+    rest of the first epoch, and after that it changes only on epoch boundaries.
+
+    CAUTION: you must call set_epoch() at the start of every epoch; if you
+    don't, the learning rate will never decay.
+
+    This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
+        warm_step: number of warmup steps; the learning rate increases linearly
+            until this point, then stays at `max_lrate` (subject to the
+            per-epoch decay below).
+        max_lrate: The learning rate at its maximum, reached on step `warm_step`
+        first_decrease_epoch: The epoch number on which to start decreasing the
+            learning rate.
+        decay_per_epoch: The factor by which the learning rate is multiplied at
+            each epoch boundary from `first_decrease_epoch` onwards, i.e. the
+            rate decays exponentially with the epoch number.
+        min_target_rms: this is a parameter of the Madam optimizer; it represents a floor
+            on the "target root-mean-square value" that is used when the initialization
+            of a tensor is zero or below this value. It may be worth optimizing.
+            Don't worry about tensors with fewer than 2 dimensions when setting this;
+            these are not subject to our l2 formula.
+        limit_grad_factor: Another parameter of Madam; you can set this to a finite
+            value, e.g. 2.0, to activate a mechanism that limits the norms of
+            larger-than-usual gradients. This seems to cause a slowdown, likely due
+            to GPU->CPU transfers, and it is disabled by setting it to infinity.
+        l2_period: mechanism to improve the optimization speed, by only applying the l2
+            regularization (which is a complicated formula) every this-many
+            minibatches. E.g. can be set to 2 or 4.
+    """
+    def __init__(self,
+                 params,
+                 max_lrate: float = 5.0e-04,
+                 warm_step: int = 25000,
+                 first_decrease_epoch: int = 1,
+                 decay_per_epoch: float = 0.85,
+                 min_target_rms: float = 0.05,
+                 limit_grad_factor: float = float('inf'),
+                 l2_period: int = 1) -> None:
+        """Construct a Gloam object."""
+        self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
+                               min_target_rms=min_target_rms,
+                               limit_grad_factor=limit_grad_factor,
+                               l2_period=l2_period)
+        self._step = 0
+        self._max_lrate = max_lrate
+        self._warm_step = warm_step
+        self._first_decrease_epoch = first_decrease_epoch
+        self._decay_per_epoch = decay_per_epoch
+        self._rate = 0
+        self._epoch = 0
+
+    @property
+    def param_groups(self):
+        """Return param_groups."""
+        return self.optimizer.param_groups
+
+    def set_epoch(self, epoch: int):
+        self._epoch = epoch
+
+    def step(self):
+        """Update parameters and rate."""
+        self._step += 1
+        rate = self.rate()
+        for p in self.optimizer.param_groups:
+            p["lr"] = rate
+        self._rate = rate
+        self.optimizer.step()
+    def rate(self, step=None):
+        """
+        Return the learning rate for a given step (self._step if `step` is None).
+        Suppose the step of optimization is 's', i.e. s = 0, 1, 2, ...
+        We define 't = s / warm_step', i.e. t is the step s, normalized so that
+        it is 1.0 at warm_step. The formula for the learning rate, as a
+        function of t and the epoch number, is:
+
+            rate = max_lrate * min(t, 1.0)
+                   * decay_per_epoch ^ max(0, epoch + 1 - first_decrease_epoch)
+
+        i.e. the rate increases linearly from t=0 to t=1, stays at max_lrate
+        until epoch `first_decrease_epoch`, and from then on is multiplied by
+        `decay_per_epoch` at each epoch boundary, so it decays exponentially
+        with the epoch number.
+
+        This schedule ultimately decreases more aggressively than Noam's,
+        which decays as 1 / sqrt(t). The reason we feel this will work better
+        in conjunction with Madam is that Madam keeps the norms of the
+        parameters approximately constant throughout training; whereas with
+        Noam, if there is no weight decay, these norms tend to increase as
+        training progresses (although rather unevenly across different
+        parameter tensors). As the norms of the parameters increase, the
+        relative changes in parameters get smaller (the step sizes don't
+        change because Adam normalizes the gradient magnitudes; they'd get
+        smaller otherwise). So Noam doesn't have to decrease the learning
+        rate too aggressively, because even with a fixed learning rate the
+        effective learning rate would be decreasing (again, this only applies
+        without weight decay).
+        """
+        if step is None:
+            step = self._step
+        t = step / self._warm_step  # floating-point division; t is the normalized step.
+        base_rate = self._max_lrate * (t if t <= 1.0 else 1.0)
+        epoch_rate = self._decay_per_epoch ** max(0, self._epoch + 1 - self._first_decrease_epoch)
+        return base_rate * epoch_rate
+
+    def zero_grad(self):
+        """Reset gradient."""
+        self.optimizer.zero_grad()
+
+    def state_dict(self):
+        """Return state_dict."""
+        return {
+            # Save the inner optimizer's state too; load_state_dict below expects it.
+            "optimizer": self.optimizer.state_dict(),
+            "_step": self._step,
+            "_epoch": self._epoch,
+        }
+
+    def load_state_dict(self, state_dict):
+        """Load state_dict. This is compatible with reading a Moam state_dict."""
+        for key, value in state_dict.items():
+            if key == "optimizer":
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            elif key == '_step':
+                self._step = value
+            elif key == '_epoch':
+                self._epoch = value
 class TestModel(torch.nn.Module):
     """Class for testing the Madam optimizer"""
     def __init__(self):

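A minimal usage sketch of the Gloam class added above (not part of the commit), assuming the `madam` module from this repo is importable; the model and data loop are stand-ins. The point is the set_epoch() call that the CAUTION comments insist on:

    import torch
    from madam import Gloam  # the class added in this commit

    model = torch.nn.Linear(10, 10)   # stand-in for the real model
    optimizer = Gloam(model.parameters(),
                      max_lrate=5.0e-04,
                      warm_step=25000,
                      first_decrease_epoch=2,
                      decay_per_epoch=0.85)

    for epoch in range(10):
        optimizer.set_epoch(epoch)    # required: otherwise the rate never decays
        for _ in range(100):          # stand-in for the real dataloader
            x = torch.randn(32, 10)
            loss = model(x).pow(2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()          # sets param_groups' lr from rate(), then steps Madam

The training-script changes below follow exactly this pattern.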

@@ -35,7 +35,7 @@ from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from madam import Foam
+from madam import Gloam
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -133,7 +133,8 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             # exp_3, vs. exp_2, is using 5e-04 not 2d-04 as max learning rate.
-            "exp_dir": Path("conformer_lm/exp_3"),
+            # exp_4, vs. exp_3, is using the Gloam optimizer with
+            "exp_dir": Path("conformer_lm/exp_4"),
             "lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
             "num_tokens": 5000,
             "blank_sym": 0,
@@ -520,9 +521,13 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])

-    optimizer = Foam(
+    # Caution: don't forget to do optimizer.set_epoch() with Gloam!
+    # Don't remove this warning!
+    optimizer = Gloam(
         model.parameters(),
-        max_lrate=params.max_lrate
+        max_lrate=params.max_lrate,
+        first_decrease_epoch=2,
+        decay_per_epoch=0.85
     )

     if checkpoints:
@@ -556,6 +561,8 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs):
         train_sampler.set_epoch(epoch)
+        optimizer.set_epoch(epoch)  # Caution: this is specific to the Gloam
+                                    # optimizer.

         cur_lr = optimizer._rate
         if tb_writer is not None:
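For sanity-checking the schedule, here is a self-contained sketch of the rate() formula above, independent of the Madam/Gloam classes; the defaults mirror the Gloam constructor:

    def gloam_rate(step: int, epoch: int,
                   max_lrate: float = 5.0e-04,
                   warm_step: int = 25000,
                   first_decrease_epoch: int = 1,
                   decay_per_epoch: float = 0.85) -> float:
        t = step / warm_step                 # normalized step; 1.0 at warm_step
        base_rate = max_lrate * min(t, 1.0)  # linear warmup, then constant
        # exponential decay at epoch boundaries, from first_decrease_epoch onwards
        epoch_rate = decay_per_epoch ** max(0, epoch + 1 - first_decrease_epoch)
        return base_rate * epoch_rate

    for epoch, step in [(0, 12500), (0, 25000), (1, 50000), (3, 100000)]:
        print(f"epoch {epoch}, step {step}: lr = {gloam_rate(step, epoch):.2e}")
    # epoch 0, step 12500: lr = 2.50e-04   (halfway through warmup)
    # epoch 0, step 25000: lr = 5.00e-04   (max_lrate)
    # epoch 1, step 50000: lr = 4.25e-04   (one decay: 5e-04 * 0.85)
    # epoch 3, step 100000: lr = 3.07e-04  (three decays: 5e-04 * 0.85^3)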