Add Foam optimizer; I used this from epoch 3.
commit ccf7bdec23
parent d045831a4f
@@ -811,6 +811,140 @@ class Moam(object):
                setattr(self, key, value)


class Foam(object):
    """
    Implements the Foam optimizer.  This is a modified version of the Noam optimizer
    proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
    but changed to use Madam (see above) instead of Adam as the base optimizer, and
    with a different learning rate schedule and a different way of specifying it.

    This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups

        warm_step: number of warmup steps before the learning rate starts to decrease
              (it increases until this point).
        max_lrate: the learning rate at its maximum, i.e. on step `warm_step`.
        knee_factor: the multiple of `warm_step` after which the learning rate will
              start to decrease more like 1/x.  The learning rate increases linearly
              from step 0 to `warm_step`, then decreases approximately as 1/sqrt(x)
              from `warm_step` to `warm_step * knee_factor`, then decreases
              approximately as 1/x from `warm_step * knee_factor` onwards.

        min_target_rms: a parameter of the Madam optimizer; it is a floor on the
              "target root-mean-square value" that is used when the initialization
              of a tensor is zero or below this value.  It may be worth tuning.
              Don't worry about tensors with fewer than 2 dimensions when setting
              this, as these are not subject to our l2 formula.
        limit_grad_factor: another parameter of Madam; you can set this to a finite
              value, e.g. 2.0, to activate a mechanism that limits the norms of
              larger-than-usual gradients.  This seems to cause a slowdown, likely
              due to GPU->CPU transfers, and it is disabled by setting it to infinity.
        l2_period: a mechanism to improve the optimization speed by only applying the
              l2 regularization (which is a complicated formula) every this-many
              minibatches.  E.g. you can set it to 2 or 4.
    """

    def __init__(self,
                 params,
                 max_lrate: float = 5.0e-04,
                 warm_step: int = 25000,
                 knee_factor: float = 8.0,
                 min_target_rms: float = 0.05,
                 limit_grad_factor: float = float('inf'),
                 l2_period: int = 1) -> None:
        """Construct a Foam object."""
        self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
                               min_target_rms=min_target_rms,
                               limit_grad_factor=limit_grad_factor,
                               l2_period=l2_period)
        self._step = 0

        self._max_lrate = max_lrate
        self._warm_step = warm_step
        self._knee_factor = knee_factor
        self._rate = 0

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update parameters and rate."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """
        Suppose the step of optimization is 's', i.e. with s = 0, 1, 2...
        We define 't = s / warm_step', i.e. t is the step s, normalized so that it
        is 1.0 at warm_step.  Our formula for the learning rate as a function of
        t is:
             rate = max_lrate * (t <= 1.0 ? t :
                                 sqrt((2 + alpha) / (1 + t + alpha t^2)))
        where alpha is chosen so that the 't' and 'alpha t^2' terms are identical
        at t == knee_factor (this means alpha = 1.0/knee_factor).  So the
        learning rate increases linearly from t=0 to t=1, and decreases after
        that.  You can see that sqrt((2 + alpha) / (1 + t + alpha t^2)) is 1.0
        when t == 1, which is why the line and the curve meet at that point.

        In the denominator of that ratio, the "t" term makes it decrease a
        bit like 1/sqrt(t) for 1 <= t <= knee_factor; the "alpha t^2" term
        makes it decrease a bit like 1/t for t > knee_factor; and the "1"
        term makes it decrease a bit more slowly than 1/sqrt(t) when t is
        quite close to 1.0 (so we linger a little near the maximum learning rate).

        This learning rate schedule ultimately decreases more aggressively
        than Noam's, i.e. as 1/t instead of 1/sqrt(t).  The reason we feel
        this will work better in conjunction with Madam is that Madam keeps
        the norms of the parameters approximately constant throughout
        training, whereas with Noam, if there is no weight decay, these
        norms tend to increase as training progresses (although rather
        unevenly across different parameter tensors).
        As the norms of the parameters increase, the relative changes in the
        parameters get smaller (the step sizes don't change, because Adam
        normalizes the gradient magnitudes; they would get smaller otherwise).
        So Noam doesn't have to decrease the learning rate very aggressively,
        because even with a fixed learning rate the effective learning rate
        would be decreasing (again, this only applies without weight decay).
        """
        if step is None:
            step = self._step
        t = step / self._warm_step  # floating point division; t is the normalized step.
        alpha = 1.0 / self._knee_factor
        return self._max_lrate * (t if t <= 1.0 else
                                  ((2 + alpha) / (1 + t + alpha * t * t)) ** 0.5)

    def zero_grad(self):
        """Reset gradient."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return state_dict."""
        return {
            "_step": self._step,
        }

    def load_state_dict(self, state_dict):
        """Load state_dict.  This is compatible with reading a Moam state_dict."""
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(state_dict["optimizer"])
            elif key == '_step':
                self._step = value


class TestModel(torch.nn.Module):
    """Class for testing the Madam optimizer"""
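A minimal usage sketch of the wrapper (illustrative only; the `torch.nn.Linear` model is a
stand-in, and it assumes this file is importable as `madam`). The Foam object is driven like a
plain optimizer and sets the Madam learning rate internally from rate() on every step(); the
test_foam() function added further down exercises the same pattern more thoroughly.

    import torch
    from madam import Foam   # assumes madam.py (this file) is on the path

    model = torch.nn.Linear(100, 100)          # stand-in for a real model
    optimizer = Foam(model.parameters(),
                     max_lrate=5.0e-04, warm_step=25000)

    for i in range(10):
        y = model(torch.randn(4, 100))
        loss = (y ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                       # updates params and the internal lr

    print(optimizer.rate())                    # current learning rate on this schedule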
@@ -844,9 +978,9 @@ def test_madam():
    inf_grad_max_count = 200
    if torch.cuda.is_available():
        devices_and_l2 = [(torch.device('cuda'), True),
-                         (torch.device('cuda'), False)]
-                         #(torch.device('cpu'), True),
-                         #(torch.device('cpu'), False)]
+                         (torch.device('cuda'), False),
+                         (torch.device('cpu'), True),
+                         (torch.device('cpu'), False)]
    else:
        devices_and_l2 = [(torch.device('cpu'), True),
                          (torch.device('cpu'), False)]
@@ -922,6 +1056,48 @@ def test_moam():
    print("")


def test_foam():
    print("Testing Foam optimizer")
    model = TestModel()
    # min_target_rms=0.01 is for testing, so that the target equals the initial RMS
    # and we can more easily tell whether our update has the desired effect.
    optimizer = Foam(model.parameters(),
                     max_lrate=1.0e-03, warm_step=300,
                     min_target_rms=0.01,
                     limit_grad_factor=4.0)

    def get_elems_rms(x: Tensor) -> float:
        return ((x ** 2).sum() / x.numel()).sqrt().item()

    for i in range(1000):
        if i % 100 == 0:
            rms_values = (get_elems_rms(model.first_layers[0].weight),
                          get_elems_rms(model.first_layers[2].weight),
                          get_elems_rms(model.conv1.weight),
                          get_elems_rms(model.conv2.weight))
            print(f"Iter {i} (Foam): stddevs = {rms_values} ")
        B = 4
        T = 20
        x = torch.randn(B, T, 100)
        y = model(x)
        yderiv = torch.randn_like(y)
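        # The next few lines occasionally scale the gradient up by 100x, and occasionally
        # make it infinite; presumably this exercises limit_grad_factor=4.0 and the
        # optimizer's handling of non-finite gradients.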
        if i % 190 <= 3 and i > 0:
            yderiv *= 100.0
        if i % 550 == 0 and i > 0:
            yderiv *= float('inf')

        y.backward(gradient=yderiv)
        optimizer.step()
        model.zero_grad()
    print("")

    state_dict = optimizer.state_dict()
    step = optimizer._step
    optimizer._step = 0
    optimizer.load_state_dict(state_dict)
    assert optimizer._step == step


def test_to_device():
    if not torch.cuda.is_available():
@@ -951,8 +1127,10 @@ def main():
    #test_to_device()
    random.seed(0)
    torch.random.manual_seed(0)
+   test_foam()
+   test_moam()
    test_madam()
    #test_moam()


if __name__ == '__main__':
@@ -1,13 +1,34 @@
import dataset
import k2
import torch
import _k2
import dataset
import os
from torch import multiprocessing as mp
import torch.distributed as dist


def local_collate_fn(sentences):
    return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)


x = _k2.RaggedInt('[[1]]')  # make sure library initialized?


if __name__ == '__main__':

    #mp.set_start_method('spawn')
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12344"

    dist.init_process_group(backend="nccl", group_name="main",
                            rank=0, world_size=1)

    train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
-   sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
+   sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
    print("len(sampler) = ", len(sampler))

    a = iter(sampler)
    print(str(next(a)))

-   collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True))
-   train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
+   train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler,
+                                          collate_fn=local_collate_fn,
+                                          num_workers=2)
    x = iter(train_dl)
    print(str(next(x)))
@@ -35,7 +35,7 @@ from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
-from madam import Moam
+from madam import Foam

from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -138,7 +138,7 @@ def get_params() -> AttributeDict:
            "blank_sym": 0,
            "bos_sym": 1,
            "eos_sym": 1,
-           "start_epoch": 0,
+           "start_epoch": 3,
            "num_epochs": 20,
            "num_valid_batches": 200,
            "symbols_per_batch": 5000,
@@ -155,8 +155,7 @@ def get_params() -> AttributeDict:
            "attention_dim": 512,
            "nhead": 8,
            "num_decoder_layers": 6,
-           "lr_factor": 2.0,
            "warm_step": 20000,
            "max_lrate": 5.0e-04
        }
    )
@@ -520,11 +519,9 @@ def run(rank, world_size, args):
    if world_size > 1:
        model = DDP(model, device_ids=[rank])

-   optimizer = Moam(
+   optimizer = Foam(
        model.parameters(),
-       model_size=params.attention_dim,
-       factor=params.lr_factor,
        warm_step=params.warm_step,
        max_lrate=params.max_lrate
    )

    if checkpoints:
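Because Foam.load_state_dict() accepts a Moam-format state_dict (it only looks at the
"optimizer" and "_step" keys if present), resuming at start_epoch = 3 from checkpoints written
while Moam was still in use should not need a conversion step. A rough sketch of that idea,
reusing `model` and `params` from run() above; the checkpoint path and the 'optimizer' key
under which the optimizer state is stored are assumptions, not taken from this commit:

    # Hypothetical illustration only: 'exp/epoch-2.pt' and the 'optimizer' key are assumptions.
    checkpoint = torch.load('exp/epoch-2.pt', map_location='cpu')
    optimizer = Foam(model.parameters(),
                     warm_step=params.warm_step,
                     max_lrate=params.max_lrate)
    optimizer.load_state_dict(checkpoint['optimizer'])  # Moam- or Foam-written state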