mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
Add Foam optimizer; I used this from epoch 3.
This commit is contained in:
parent
d045831a4f
commit
ccf7bdec23
@ -811,6 +811,140 @@ class Moam(object):
|
|||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
class Foam(object):
|
||||||
|
"""
|
||||||
|
Implements Foam optimizer. This is a modified version of the Noam optimizer
|
||||||
|
which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
|
||||||
|
but changed to use Madam (see above) instead of Adam as the base optimizer, and then
|
||||||
|
to change the learning rate schedule and how it is specified.
|
||||||
|
|
||||||
|
|
||||||
|
This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
|
||||||
|
|
||||||
|
warm_step: number of warmup steps before the learning rate starts to decrease
|
||||||
|
(it increases until this point).
|
||||||
|
max_lrate: The learning rate at its maximum, on step `warm_step`
|
||||||
|
knee_factor: The multiple of `max_lrate` after which the learning rate will
|
||||||
|
start to decrease more like 1/x. It increases linearly from 0 to
|
||||||
|
`warm_step`, then decreases approximately as 1/sqrt(x) from
|
||||||
|
`warm_step` to `warm_step * knee_factor`, then decreases
|
||||||
|
approximately as 1/x from `warm_step * knee_factor` onwards.
|
||||||
|
|
||||||
|
min_target_rms: this is a parameter of the Madam optimizer; it represents a floor
|
||||||
|
on the "target root-mean-square value" that is used when the initialization
|
||||||
|
of a tensor is zero or below this value. It may be worth optimizing.
|
||||||
|
Don't worry about tensors with fewer than 2 dimensions when setting this,
|
||||||
|
these are not subject to our l2 formula.
|
||||||
|
limit_grad_factor: Another parameter of Madam, you can set this to a finite
|
||||||
|
value, e.g. 2.0, to activate a mechanism that limits the norms of
|
||||||
|
larger-than-usual gradients. This seems to cause a slowdown, likely due
|
||||||
|
to GPU->CPU transfers, and it is disabled by setting it to infinity.
|
||||||
|
l2_period: mechanism to improve the optimization speed, by only applying the l2
|
||||||
|
regularization (which is a complicated formula) every this-many
|
||||||
|
minibatches. E.g. can set it to 2 or 4.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
params,
|
||||||
|
max_lrate: float = 5.0e-04,
|
||||||
|
warm_step: int = 25000,
|
||||||
|
knee_factor: float = 8.0,
|
||||||
|
min_target_rms: float = 0.05,
|
||||||
|
limit_grad_factor: float = float('inf'),
|
||||||
|
l2_period: int = 1) -> None:
|
||||||
|
"""Construct an Noam object."""
|
||||||
|
self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
|
||||||
|
min_target_rms=min_target_rms,
|
||||||
|
limit_grad_factor=limit_grad_factor,
|
||||||
|
l2_period=l2_period)
|
||||||
|
self._step = 0
|
||||||
|
|
||||||
|
self._max_lrate = max_lrate
|
||||||
|
self._warm_step = warm_step
|
||||||
|
self._knee_factor = knee_factor
|
||||||
|
self._rate = 0
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def param_groups(self):
|
||||||
|
"""Return param_groups."""
|
||||||
|
return self.optimizer.param_groups
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
"""Update parameters and rate."""
|
||||||
|
self._step += 1
|
||||||
|
rate = self.rate()
|
||||||
|
for p in self.optimizer.param_groups:
|
||||||
|
p["lr"] = rate
|
||||||
|
self._rate = rate
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
def rate(self, step=None):
|
||||||
|
"""
|
||||||
|
Suppose the step of optimization is 's', i.e. with s = 0, 1, 2...
|
||||||
|
We define 't = s / warm_step', i.e. t is the step s, normalized so that it
|
||||||
|
is 1.0 at warm_step. Our formula for the learning rate as a function of
|
||||||
|
t is:
|
||||||
|
rate = max_lrate * (t <= 1.0 ? t :
|
||||||
|
sqrt((2 + alpha) / (1 + t + alpha t^2)))
|
||||||
|
where alpha is chosen so that the 't' and 'alpha t^2' terms are identical
|
||||||
|
at t == knee_factor (this means alpha = 1.0/knee_factor). So the
|
||||||
|
learning rate increases linearly from t=00 to t=1, and decreases
|
||||||
|
after that. You can see
|
||||||
|
that sqrt((2 + alpha) / (1 + t + alpha t^2))) is 1.0 when t == 1,
|
||||||
|
which is why the line and the curve meet at that point.
|
||||||
|
|
||||||
|
On the denominator of that ratio, the "t" term makes it decrease a
|
||||||
|
bit like 1/sqrt(t) in 1 <= t <= warm_step; the "alpha t^2" term
|
||||||
|
makes it decrease a bit like 1/t for t > warm_step; and the "1"
|
||||||
|
term makes it decrease a bit slower than 1/sqrt(t) when t is quite
|
||||||
|
close to 1.0 (so we linger a little, near the maximum learning rate).
|
||||||
|
|
||||||
|
This learning rate schedule ultimately decreases more aggressively
|
||||||
|
than Noam, i.e. as 1 / t instead of 1 / sqrt(t). The reason we
|
||||||
|
feel this will work better in conjunction with Madam, is that Madam
|
||||||
|
keeps the norms of the parameters approximately constant throughout
|
||||||
|
training; whereas with Noam, if there is no weight decay, these
|
||||||
|
norms tend to increase as training progresses (although rather
|
||||||
|
unevenly across different parameter tensors).
|
||||||
|
As the norms of the parameters increase, the relative changes
|
||||||
|
in parameters get smaller (the step sizes don't change because
|
||||||
|
Adam normalizes the gradient magnitudes; they'd get smaller otherwise).
|
||||||
|
So Noam doesn't have to decrease the learning rate too aggressively
|
||||||
|
because even with a fixed learning rate, the effective learning rate
|
||||||
|
would be decreasing (again, this only applies without weight decay).
|
||||||
|
"""
|
||||||
|
if step is None:
|
||||||
|
step = self._step
|
||||||
|
t = step / self._warm_step # floating point division.. t is the normalized step.
|
||||||
|
alpha = 1.0 / self._knee_factor
|
||||||
|
return self._max_lrate * (t if t <= 1.0 else
|
||||||
|
((2 + alpha) / (1 + t + alpha * t * t)) ** 0.5)
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
"""Reset gradient."""
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
"""Return state_dict."""
|
||||||
|
return {
|
||||||
|
"_step": self._step,
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
"""Load state_dict. This is compatible with reading a Moam state_dict"""
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if key == "optimizer":
|
||||||
|
self.optimizer.load_state_dict(state_dict["optimizer"])
|
||||||
|
elif key == '_step':
|
||||||
|
self._step = value
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestModel(torch.nn.Module):
|
class TestModel(torch.nn.Module):
|
||||||
"""Class for testing the Madam optimizer"""
|
"""Class for testing the Madam optimizer"""
|
||||||
@ -844,9 +978,9 @@ def test_madam():
|
|||||||
inf_grad_max_count = 200
|
inf_grad_max_count = 200
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
devices_and_l2 = [(torch.device('cuda'), True),
|
devices_and_l2 = [(torch.device('cuda'), True),
|
||||||
(torch.device('cuda'), False)]
|
(torch.device('cuda'), False),
|
||||||
#(torch.device('cpu'), True),
|
(torch.device('cpu'), True),
|
||||||
#(torch.device('cpu'), False)]
|
(torch.device('cpu'), False)]
|
||||||
else:
|
else:
|
||||||
devices_and_l2 = [(torch.device('cpu'), True),
|
devices_and_l2 = [(torch.device('cpu'), True),
|
||||||
(torch.device('cpu'), False)]
|
(torch.device('cpu'), False)]
|
||||||
@ -922,6 +1056,48 @@ def test_moam():
|
|||||||
print("")
|
print("")
|
||||||
|
|
||||||
|
|
||||||
|
def test_foam():
|
||||||
|
print("Testing Foam optimizer")
|
||||||
|
model = TestModel()
|
||||||
|
# min_target_rms=0.01 is for testing, so the target equals the initial RMS
|
||||||
|
# and we can more easily tell whether our update has the desired effect.
|
||||||
|
optimizer = Foam(model.parameters(),
|
||||||
|
max_lrate=1.0e-03, warm_step=300,
|
||||||
|
min_target_rms=0.01,
|
||||||
|
limit_grad_factor=4.0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_elems_rms(x: Tensor) -> Tensor:
|
||||||
|
return ((x ** 2).sum() / x.numel()).sqrt().item()
|
||||||
|
|
||||||
|
for i in range(1000):
|
||||||
|
if i % 100 == 0:
|
||||||
|
rms_values = (get_elems_rms(model.first_layers[0].weight),
|
||||||
|
get_elems_rms(model.first_layers[2].weight),
|
||||||
|
get_elems_rms(model.conv1.weight),
|
||||||
|
get_elems_rms(model.conv2.weight))
|
||||||
|
print(f"Iter {i} (Foam): stddevs = {rms_values} ")
|
||||||
|
B = 4
|
||||||
|
T = 20
|
||||||
|
x = torch.randn(B, T, 100)
|
||||||
|
y = model(x)
|
||||||
|
yderiv = torch.randn_like(y)
|
||||||
|
if i % 190 <= 3 and i > 0:
|
||||||
|
yderiv *= 100.0
|
||||||
|
if i % 550 == 0 and i > 0:
|
||||||
|
yderiv *= float('inf')
|
||||||
|
|
||||||
|
y.backward(gradient=yderiv)
|
||||||
|
optimizer.step()
|
||||||
|
model.zero_grad()
|
||||||
|
print("")
|
||||||
|
|
||||||
|
state_dict = optimizer.state_dict()
|
||||||
|
step = optimizer._step
|
||||||
|
optimizer._step = 0
|
||||||
|
optimizer.load_state_dict(state_dict)
|
||||||
|
assert optimizer._step == step
|
||||||
|
|
||||||
|
|
||||||
def test_to_device():
|
def test_to_device():
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
@ -951,8 +1127,10 @@ def main():
|
|||||||
#test_to_device()
|
#test_to_device()
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
torch.random.manual_seed(0)
|
torch.random.manual_seed(0)
|
||||||
|
test_foam()
|
||||||
|
test_moam()
|
||||||
test_madam()
|
test_madam()
|
||||||
#test_moam()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -1,13 +1,34 @@
|
|||||||
import dataset
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
import _k2
|
||||||
|
import dataset
|
||||||
|
import os
|
||||||
|
from torch import multiprocessing as mp
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
def local_collate_fn(sentences):
|
||||||
|
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)
|
||||||
|
|
||||||
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
x = _k2.RaggedInt('[[1]]') # make sure library initialized?
|
||||||
sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
|
|
||||||
a = iter(sampler)
|
|
||||||
print(str(next(a)))
|
|
||||||
|
|
||||||
collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True))
|
if __name__ == '__main__':
|
||||||
train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
|
|
||||||
x = iter(train_dl)
|
#mp.set_start_method('spawn')
|
||||||
print(str(next(x)))
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = "12344"
|
||||||
|
|
||||||
|
dist.init_process_group(backend="nccl", group_name="main",
|
||||||
|
rank=0, world_size=1)
|
||||||
|
|
||||||
|
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
||||||
|
sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
|
||||||
|
print("len(sampler) = ", len(sampler))
|
||||||
|
|
||||||
|
a = iter(sampler)
|
||||||
|
print(str(next(a)))
|
||||||
|
|
||||||
|
train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler,
|
||||||
|
collate_fn=local_collate_fn,
|
||||||
|
num_workers=2)
|
||||||
|
x = iter(train_dl)
|
||||||
|
print(str(next(x)))
|
||||||
|
@ -35,7 +35,7 @@ from lhotse.utils import fix_random_seed
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from madam import Moam
|
from madam import Foam
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
@ -138,7 +138,7 @@ def get_params() -> AttributeDict:
|
|||||||
"blank_sym": 0,
|
"blank_sym": 0,
|
||||||
"bos_sym": 1,
|
"bos_sym": 1,
|
||||||
"eos_sym": 1,
|
"eos_sym": 1,
|
||||||
"start_epoch": 0,
|
"start_epoch": 3,
|
||||||
"num_epochs": 20,
|
"num_epochs": 20,
|
||||||
"num_valid_batches": 200,
|
"num_valid_batches": 200,
|
||||||
"symbols_per_batch": 5000,
|
"symbols_per_batch": 5000,
|
||||||
@ -155,8 +155,7 @@ def get_params() -> AttributeDict:
|
|||||||
"attention_dim": 512,
|
"attention_dim": 512,
|
||||||
"nhead": 8,
|
"nhead": 8,
|
||||||
"num_decoder_layers": 6,
|
"num_decoder_layers": 6,
|
||||||
"lr_factor": 2.0,
|
"max_lrate": 5.0e-04
|
||||||
"warm_step": 20000,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -520,11 +519,9 @@ def run(rank, world_size, args):
|
|||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
model = DDP(model, device_ids=[rank])
|
model = DDP(model, device_ids=[rank])
|
||||||
|
|
||||||
optimizer = Moam(
|
optimizer = Foam(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
model_size=params.attention_dim,
|
max_lrate=params.max_lrate
|
||||||
factor=params.lr_factor,
|
|
||||||
warm_step=params.warm_step,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if checkpoints:
|
if checkpoints:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user