mirror of https://github.com/k2-fsa/icefall.git
Move to Gloam optimizer, exponential lrate
This commit is contained in:
parent d313c27c14
commit 56a88badd1
@@ -813,11 +813,6 @@ class Moam(object):
 class Foam(object):
     """
-    Implements Foam optimizer. This is a modified version of the Noam optimizer
-    which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
-    but changed to use Madam (see above) instead of Adam as the base optimizer, and then
-    to change the learning rate schedule and how it is specified.
-
 
     This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
 
@@ -946,6 +941,153 @@ class Foam(object):
 
 
 
+class Gloam(object):
+    """
+    Implements Gloam optimizer. This is a modified version of the Noam optimizer
+    which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
+    but changed to use Madam (see above) instead of Adam as the base optimizer, and then
+    to change the learning rate schedule and how it is specified. We have
+    a warm-up stage, but after it gets to `max_lrate` it stays constant for the
+    rest of the 1st epoch, and after that, only changes on epoch boundaries.
+
+    CAUTION: you have to call set_epoch() every epoch, to set the epoch. If you don't do this,
+    this won't work!
+
+
+    This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
+
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
+
+        warm_step: number of warmup steps before the learning rate starts to decrease
+              (it increases until this point).
+        max_lrate: The learning rate at its maximum, on step `warm_step`
+        first_decrease_epoch: The epoch number on which to start decreasing the
+              learning rate.
+        decay_per_epoch: The factor by which the learning rate decreases per epoch,
+              starting from epoch `first_decrease_epoch` (e.g. 0.85).
+        min_target_rms: this is a parameter of the Madam optimizer; it represents a floor
+              on the "target root-mean-square value" that is used when the initialization
+              of a tensor is zero or below this value. It may be worth optimizing.
+              Don't worry about tensors with fewer than 2 dimensions when setting this,
+              these are not subject to our l2 formula.
+        limit_grad_factor: Another parameter of Madam, you can set this to a finite
+              value, e.g. 2.0, to activate a mechanism that limits the norms of
+              larger-than-usual gradients. This seems to cause a slowdown, likely due
+              to GPU->CPU transfers, and it is disabled by setting it to infinity.
+        l2_period: mechanism to improve the optimization speed, by only applying the l2
+              regularization (which is a complicated formula) every this-many
+              minibatches. E.g. can set it to 2 or 4.
+    """
+
+    def __init__(self,
+                 params,
+                 max_lrate: float = 5.0e-04,
+                 warm_step: int = 25000,
+                 first_decrease_epoch: int = 1,
+                 decay_per_epoch: float = 0.85,
+                 min_target_rms: float = 0.05,
+                 limit_grad_factor: float = float('inf'),
+                 l2_period: int = 1) -> None:
+        """Construct a Gloam object."""
+        self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
+                               min_target_rms=min_target_rms,
+                               limit_grad_factor=limit_grad_factor,
+                               l2_period=l2_period)
+        self._step = 0
+
+        self._max_lrate = max_lrate
+        self._warm_step = warm_step
+        self._first_decrease_epoch = first_decrease_epoch
+        self._decay_per_epoch = decay_per_epoch
+        self._rate = 0
+        self._epoch = 0
+
+    @property
+    def param_groups(self):
+        """Return param_groups."""
+        return self.optimizer.param_groups
+
+    def set_epoch(self, epoch: int):
+        self._epoch = epoch
+
+    def step(self):
+        """Update parameters and rate."""
+        self._step += 1
+        rate = self.rate()
+        for p in self.optimizer.param_groups:
+            p["lr"] = rate
+        self._rate = rate
+        self.optimizer.step()
+
+    def rate(self, step=None):
+        """
+        Suppose the step of optimization is 's', i.e. with s = 0, 1, 2...
+        We define 't = s / warm_step', i.e. t is the step s, normalized so that it
+        is 1.0 at warm_step. Our formula for the learning rate as a function of
+        t is:
+             rate = max_lrate * (t <= 1.0 ? t :
+                                 sqrt((2 + alpha) / (1 + t + alpha t^2)))
+        where alpha is chosen so that the 't' and 'alpha t^2' terms are identical
+        at t == knee_factor (this means alpha = 1.0/knee_factor). So the
+        learning rate increases linearly from t=0 to t=1, and decreases
+        after that. You can see
+        that sqrt((2 + alpha) / (1 + t + alpha t^2)) is 1.0 when t == 1,
+        which is why the line and the curve meet at that point.
+
+        On the denominator of that ratio, the "t" term makes it decrease a
+        bit like 1/sqrt(t) in 1 <= t <= warm_step; the "alpha t^2" term
+        makes it decrease a bit like 1/t for t > warm_step; and the "1"
+        term makes it decrease a bit slower than 1/sqrt(t) when t is quite
+        close to 1.0 (so we linger a little, near the maximum learning rate).
+
+        This learning rate schedule ultimately decreases more aggressively
+        than Noam, i.e. as 1 / t instead of 1 / sqrt(t). The reason we
+        feel this will work better in conjunction with Madam, is that Madam
+        keeps the norms of the parameters approximately constant throughout
+        training; whereas with Noam, if there is no weight decay, these
+        norms tend to increase as training progresses (although rather
+        unevenly across different parameter tensors).
+        As the norms of the parameters increase, the relative changes
+        in parameters get smaller (the step sizes don't change because
+        Adam normalizes the gradient magnitudes; they'd get smaller otherwise).
+        So Noam doesn't have to decrease the learning rate too aggressively
+        because even with a fixed learning rate, the effective learning rate
+        would be decreasing (again, this only applies without weight decay).
+        """
+        if step is None:
+            step = self._step
+        t = step / self._warm_step  # floating point division.. t is the normalized step.
+        base_rate = self._max_lrate * (t if t <= 1.0 else 1.0)
+        epoch_rate = self._decay_per_epoch ** max(0, self._epoch + 1 - self._first_decrease_epoch)
+        return base_rate * epoch_rate
+
+    def zero_grad(self):
+        """Reset gradient."""
+        self.optimizer.zero_grad()
+
+    def state_dict(self):
+        """Return state_dict."""
+        return {
+            "_step": self._step,
+            "_epoch": self._epoch,
+        }
+
+    def load_state_dict(self, state_dict):
+        """Load state_dict. This is compatible with reading a Moam state_dict"""
+        for key, value in state_dict.items():
+            if key == "optimizer":
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            elif key == '_step':
+                self._step = value
+            elif key == '_epoch':
+                self._epoch = value
+
+
 class TestModel(torch.nn.Module):
     """Class for testing the Madam optimizer"""
     def __init__(self):
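To make the Gloam schedule above concrete: the learning rate ramps up linearly to max_lrate over warm_step steps, stays there for the rest of the early epochs, and is then scaled by decay_per_epoch on each epoch boundary once the epoch number reaches first_decrease_epoch. The following standalone sketch (not part of the commit; the constants mirror the defaults above and the train.py settings below) reproduces the arithmetic of Gloam.rate():

def gloam_rate(step: int,
               epoch: int,
               max_lrate: float = 5.0e-04,
               warm_step: int = 25000,
               first_decrease_epoch: int = 2,
               decay_per_epoch: float = 0.85) -> float:
    # Same arithmetic as Gloam.rate() above, written out for illustration.
    t = step / warm_step                                  # normalized step
    base_rate = max_lrate * (t if t <= 1.0 else 1.0)      # linear warm-up, then flat
    epoch_rate = decay_per_epoch ** max(0, epoch + 1 - first_decrease_epoch)
    return base_rate * epoch_rate

# A few sample points:
print(gloam_rate(step=12500, epoch=0))   # 2.5e-04, half-way up the warm-up ramp
print(gloam_rate(step=25000, epoch=0))   # 5.0e-04, the maximum
print(gloam_rate(step=80000, epoch=3))   # 5.0e-04 * 0.85**2 ~= 3.6e-04

Note that the long docstring of rate() still describes the Foam-style decay curve sqrt((2 + alpha) / (1 + t + alpha t^2)), which equals 1 at t == 1 because (2 + alpha) / (1 + 1 + alpha) == 1; the code itself simply clamps the per-step rate at max_lrate and leaves all of the decay to the per-epoch factor.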
@@ -35,7 +35,7 @@ from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from madam import Foam
+from madam import Gloam
 
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -133,7 +133,8 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             # exp_3, vs. exp_2, is using 5e-04 not 2e-04 as max learning rate.
"exp_dir": Path("conformer_lm/exp_3"),
|
# exp_4, vs. exp_3, is using the Gloam optimizer with
|
||||||
|
"exp_dir": Path("conformer_lm/exp_4"),
|
||||||
"lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
|
"lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
|
||||||
"num_tokens": 5000,
|
"num_tokens": 5000,
|
||||||
"blank_sym": 0,
|
"blank_sym": 0,
|
||||||
@@ -520,9 +521,13 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = Foam(
+    # Caution: don't forget to do optimizer.set_epoch() with Gloam!
+    # Don't remove this warning!
+    optimizer = Gloam(
         model.parameters(),
-        max_lrate=params.max_lrate
+        max_lrate=params.max_lrate,
+        first_decrease_epoch=2,
+        decay_per_epoch=0.85
     )
 
     if checkpoints:
@@ -556,6 +561,8 @@ def run(rank, world_size, args):
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_sampler.set_epoch(epoch)
+        optimizer.set_epoch(epoch)  # Caution: this is specific to the Gloam
+                                    # optimizer.
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
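For reference, here is a minimal usage sketch of the new optimizer outside this training script (hypothetical model and loop, assuming madam.py is importable). The essential difference from Foam is that set_epoch() must be called at the start of every epoch, exactly as the added lines above do; otherwise the per-epoch decay never takes effect.

import torch
from madam import Gloam   # the module modified in this commit

model = torch.nn.Linear(10, 10)              # stand-in for the real model
optimizer = Gloam(model.parameters(),
                  max_lrate=5.0e-04,
                  first_decrease_epoch=2,
                  decay_per_epoch=0.85)

for epoch in range(10):
    optimizer.set_epoch(epoch)               # required: drives the per-epoch decay
    for _ in range(100):                     # stand-in for iterating over real batches
        loss = model(torch.randn(4, 10)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                     # also advances the step counter and the LR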