Add Foam optimizer; I used this from epoch 3.

This commit is contained in:
Daniel Povey 2021-08-28 21:51:54 +08:00
parent d045831a4f
commit ccf7bdec23
3 changed files with 217 additions and 21 deletions

View File

@ -811,6 +811,140 @@ class Moam(object):
setattr(self, key, value) setattr(self, key, value)
class Foam(object):
"""
Implements Foam optimizer. This is a modified version of the Noam optimizer
which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
but changed to use Madam (see above) instead of Adam as the base optimizer, and then
to change the learning rate schedule and how it is specified.
This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
warm_step: number of warmup steps before the learning rate starts to decrease
(it increases until this point).
max_lrate: The learning rate at its maximum, on step `warm_step`
knee_factor: The multiple of `max_lrate` after which the learning rate will
start to decrease more like 1/x. It increases linearly from 0 to
`warm_step`, then decreases approximately as 1/sqrt(x) from
`warm_step` to `warm_step * knee_factor`, then decreases
approximately as 1/x from `warm_step * knee_factor` onwards.
min_target_rms: this is a parameter of the Madam optimizer; it represents a floor
on the "target root-mean-square value" that is used when the initialization
of a tensor is zero or below this value. It may be worth optimizing.
Don't worry about tensors with fewer than 2 dimensions when setting this,
these are not subject to our l2 formula.
limit_grad_factor: Another parameter of Madam, you can set this to a finite
value, e.g. 2.0, to activate a mechanism that limits the norms of
larger-than-usual gradients. This seems to cause a slowdown, likely due
to GPU->CPU transfers, and it is disabled by setting it to infinity.
l2_period: mechanism to improve the optimization speed, by only applying the l2
regularization (which is a complicated formula) every this-many
minibatches. E.g. can set it to 2 or 4.
"""
def __init__(self,
params,
max_lrate: float = 5.0e-04,
warm_step: int = 25000,
knee_factor: float = 8.0,
min_target_rms: float = 0.05,
limit_grad_factor: float = float('inf'),
l2_period: int = 1) -> None:
"""Construct an Noam object."""
self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
min_target_rms=min_target_rms,
limit_grad_factor=limit_grad_factor,
l2_period=l2_period)
self._step = 0
self._max_lrate = max_lrate
self._warm_step = warm_step
self._knee_factor = knee_factor
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""
Suppose the step of optimization is 's', i.e. with s = 0, 1, 2...
We define 't = s / warm_step', i.e. t is the step s, normalized so that it
is 1.0 at warm_step. Our formula for the learning rate as a function of
t is:
rate = max_lrate * (t <= 1.0 ? t :
sqrt((2 + alpha) / (1 + t + alpha t^2)))
where alpha is chosen so that the 't' and 'alpha t^2' terms are identical
at t == knee_factor (this means alpha = 1.0/knee_factor). So the
learning rate increases linearly from t=00 to t=1, and decreases
after that. You can see
that sqrt((2 + alpha) / (1 + t + alpha t^2))) is 1.0 when t == 1,
which is why the line and the curve meet at that point.
On the denominator of that ratio, the "t" term makes it decrease a
bit like 1/sqrt(t) in 1 <= t <= warm_step; the "alpha t^2" term
makes it decrease a bit like 1/t for t > warm_step; and the "1"
term makes it decrease a bit slower than 1/sqrt(t) when t is quite
close to 1.0 (so we linger a little, near the maximum learning rate).
This learning rate schedule ultimately decreases more aggressively
than Noam, i.e. as 1 / t instead of 1 / sqrt(t). The reason we
feel this will work better in conjunction with Madam, is that Madam
keeps the norms of the parameters approximately constant throughout
training; whereas with Noam, if there is no weight decay, these
norms tend to increase as training progresses (although rather
unevenly across different parameter tensors).
As the norms of the parameters increase, the relative changes
in parameters get smaller (the step sizes don't change because
Adam normalizes the gradient magnitudes; they'd get smaller otherwise).
So Noam doesn't have to decrease the learning rate too aggressively
because even with a fixed learning rate, the effective learning rate
would be decreasing (again, this only applies without weight decay).
"""
if step is None:
step = self._step
t = step / self._warm_step # floating point division.. t is the normalized step.
alpha = 1.0 / self._knee_factor
return self._max_lrate * (t if t <= 1.0 else
((2 + alpha) / (1 + t + alpha * t * t)) ** 0.5)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
}
def load_state_dict(self, state_dict):
"""Load state_dict. This is compatible with reading a Moam state_dict"""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
elif key == '_step':
self._step = value
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
"""Class for testing the Madam optimizer""" """Class for testing the Madam optimizer"""
@ -844,9 +978,9 @@ def test_madam():
inf_grad_max_count = 200 inf_grad_max_count = 200
if torch.cuda.is_available(): if torch.cuda.is_available():
devices_and_l2 = [(torch.device('cuda'), True), devices_and_l2 = [(torch.device('cuda'), True),
(torch.device('cuda'), False)] (torch.device('cuda'), False),
#(torch.device('cpu'), True), (torch.device('cpu'), True),
#(torch.device('cpu'), False)] (torch.device('cpu'), False)]
else: else:
devices_and_l2 = [(torch.device('cpu'), True), devices_and_l2 = [(torch.device('cpu'), True),
(torch.device('cpu'), False)] (torch.device('cpu'), False)]
@ -922,6 +1056,48 @@ def test_moam():
print("") print("")
def test_foam():
print("Testing Foam optimizer")
model = TestModel()
# min_target_rms=0.01 is for testing, so the target equals the initial RMS
# and we can more easily tell whether our update has the desired effect.
optimizer = Foam(model.parameters(),
max_lrate=1.0e-03, warm_step=300,
min_target_rms=0.01,
limit_grad_factor=4.0)
def get_elems_rms(x: Tensor) -> Tensor:
return ((x ** 2).sum() / x.numel()).sqrt().item()
for i in range(1000):
if i % 100 == 0:
rms_values = (get_elems_rms(model.first_layers[0].weight),
get_elems_rms(model.first_layers[2].weight),
get_elems_rms(model.conv1.weight),
get_elems_rms(model.conv2.weight))
print(f"Iter {i} (Foam): stddevs = {rms_values} ")
B = 4
T = 20
x = torch.randn(B, T, 100)
y = model(x)
yderiv = torch.randn_like(y)
if i % 190 <= 3 and i > 0:
yderiv *= 100.0
if i % 550 == 0 and i > 0:
yderiv *= float('inf')
y.backward(gradient=yderiv)
optimizer.step()
model.zero_grad()
print("")
state_dict = optimizer.state_dict()
step = optimizer._step
optimizer._step = 0
optimizer.load_state_dict(state_dict)
assert optimizer._step == step
def test_to_device(): def test_to_device():
if not torch.cuda.is_available(): if not torch.cuda.is_available():
@ -951,8 +1127,10 @@ def main():
#test_to_device() #test_to_device()
random.seed(0) random.seed(0)
torch.random.manual_seed(0) torch.random.manual_seed(0)
test_foam()
test_moam()
test_madam() test_madam()
#test_moam()
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,13 +1,34 @@
import dataset import k2
import torch import torch
import _k2
import dataset
import os
from torch import multiprocessing as mp
import torch.distributed as dist
def local_collate_fn(sentences):
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt') x = _k2.RaggedInt('[[1]]') # make sure library initialized?
sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
a = iter(sampler)
print(str(next(a)))
collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)) if __name__ == '__main__':
train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
x = iter(train_dl) #mp.set_start_method('spawn')
print(str(next(x))) os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12344"
dist.init_process_group(backend="nccl", group_name="main",
rank=0, world_size=1)
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
print("len(sampler) = ", len(sampler))
a = iter(sampler)
print(str(next(a)))
train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler,
collate_fn=local_collate_fn,
num_workers=2)
x = iter(train_dl)
print(str(next(x)))

View File

@ -35,7 +35,7 @@ from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from madam import Moam from madam import Foam
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@ -138,7 +138,7 @@ def get_params() -> AttributeDict:
"blank_sym": 0, "blank_sym": 0,
"bos_sym": 1, "bos_sym": 1,
"eos_sym": 1, "eos_sym": 1,
"start_epoch": 0, "start_epoch": 3,
"num_epochs": 20, "num_epochs": 20,
"num_valid_batches": 200, "num_valid_batches": 200,
"symbols_per_batch": 5000, "symbols_per_batch": 5000,
@ -155,8 +155,7 @@ def get_params() -> AttributeDict:
"attention_dim": 512, "attention_dim": 512,
"nhead": 8, "nhead": 8,
"num_decoder_layers": 6, "num_decoder_layers": 6,
"lr_factor": 2.0, "max_lrate": 5.0e-04
"warm_step": 20000,
} }
) )
@ -520,11 +519,9 @@ def run(rank, world_size, args):
if world_size > 1: if world_size > 1:
model = DDP(model, device_ids=[rank]) model = DDP(model, device_ids=[rank])
optimizer = Moam( optimizer = Foam(
model.parameters(), model.parameters(),
model_size=params.attention_dim, max_lrate=params.max_lrate
factor=params.lr_factor,
warm_step=params.warm_step,
) )
if checkpoints: if checkpoints: