diff --git a/egs/librispeech/ASR/conformer_lm/madam.py b/egs/librispeech/ASR/conformer_lm/madam.py
index 36716efec..bc8168330 100644
--- a/egs/librispeech/ASR/conformer_lm/madam.py
+++ b/egs/librispeech/ASR/conformer_lm/madam.py
@@ -813,11 +813,6 @@ class Moam(object):
 
 class Foam(object):
     """
-    Implements Foam optimizer. This is a modified version of the Noam optimizer
-    which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
-    but changed to use Madam (see above) instead of Adam as the base optimizer, and then
-    to change the learning rate schedule and how it is specified.
-
     This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
 
 
@@ -946,6 +941,153 @@ class Foam(object):
 
 
 
+class Gloam(object):
+    """
+    Implements Gloam optimizer.  This is a modified version of the Noam
+    optimizer which was proposed in "Attention Is All You Need",
+    https://arxiv.org/pdf/1706.03762.pdf, but changed to use Madam (see above)
+    instead of Adam as the base optimizer, and then to change the learning
+    rate schedule and how it is specified.  We have a warm-up stage, but after
+    the learning rate reaches `max_lrate` it stays constant for the rest of
+    the 1st epoch, and after that it changes only on epoch boundaries.
+
+    CAUTION: you have to call set_epoch() every epoch, to tell this object the
+    current epoch number.  If you don't do this, it won't work!
+
+    This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts
+            defining parameter groups
+        warm_step: the number of warmup steps; the learning rate increases
+            linearly until it reaches `max_lrate` at step `warm_step`.
+        max_lrate: the learning rate at its maximum, reached at step
+            `warm_step`.
+        first_decrease_epoch: the epoch number on which to start decreasing
+            the learning rate.
+        decay_per_epoch: the factor by which the learning rate is multiplied
+            on each successive epoch, from `first_decrease_epoch` onward.
+        min_target_rms: this is a parameter of the Madam optimizer; it
+            represents a floor on the "target root-mean-square value" that is
+            used when the initialization of a tensor is zero or below this
+            value.  It may be worth optimizing.  Don't worry about tensors
+            with fewer than 2 dimensions when setting this; these are not
+            subject to our l2 formula.
+        limit_grad_factor: another parameter of Madam; you can set this to a
+            finite value, e.g. 2.0, to activate a mechanism that limits the
+            norms of larger-than-usual gradients.  This seems to cause a
+            slowdown, likely due to GPU->CPU transfers, and it is disabled by
+            setting it to infinity.
+        l2_period: a mechanism to improve the optimization speed, by applying
+            the l2 regularization (which is a complicated formula) only every
+            this-many minibatches, e.g. every 2 or 4.
+ """ + + def __init__(self, + params, + max_lrate: float = 5.0e-04, + warm_step: int = 25000, + first_decrease_epoch: int = 1, + decay_per_epoch: float = 0.85, + min_target_rms: float = 0.05, + limit_grad_factor: float = float('inf'), + l2_period: int = 1) -> None: + """Construct an Noam object.""" + self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9, + min_target_rms=min_target_rms, + limit_grad_factor=limit_grad_factor, + l2_period=l2_period) + self._step = 0 + + self._max_lrate = max_lrate + self._warm_step = warm_step + self._first_decrease_epoch = first_decrease_epoch + self._decay_per_epoch = decay_per_epoch + self._rate = 0 + self._epoch = 0 + + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def set_epoch(self, epoch: int): + self._epoch = epoch + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + + def rate(self, step=None): + """ + Suppose the step of optimization is 's', i.e. with s = 0, 1, 2... + We define 't = s / warm_step', i.e. t is the step s, normalized so that it + is 1.0 at warm_step. Our formula for the learning rate as a function of + t is: + rate = max_lrate * (t <= 1.0 ? t : + sqrt((2 + alpha) / (1 + t + alpha t^2))) + where alpha is chosen so that the 't' and 'alpha t^2' terms are identical + at t == knee_factor (this means alpha = 1.0/knee_factor). So the + learning rate increases linearly from t=00 to t=1, and decreases + after that. You can see + that sqrt((2 + alpha) / (1 + t + alpha t^2))) is 1.0 when t == 1, + which is why the line and the curve meet at that point. + + On the denominator of that ratio, the "t" term makes it decrease a + bit like 1/sqrt(t) in 1 <= t <= warm_step; the "alpha t^2" term + makes it decrease a bit like 1/t for t > warm_step; and the "1" + term makes it decrease a bit slower than 1/sqrt(t) when t is quite + close to 1.0 (so we linger a little, near the maximum learning rate). + + This learning rate schedule ultimately decreases more aggressively + than Noam, i.e. as 1 / t instead of 1 / sqrt(t). The reason we + feel this will work better in conjunction with Madam, is that Madam + keeps the norms of the parameters approximately constant throughout + training; whereas with Noam, if there is no weight decay, these + norms tend to increase as training progresses (although rather + unevenly across different parameter tensors). + As the norms of the parameters increase, the relative changes + in parameters get smaller (the step sizes don't change because + Adam normalizes the gradient magnitudes; they'd get smaller otherwise). + So Noam doesn't have to decrease the learning rate too aggressively + because even with a fixed learning rate, the effective learning rate + would be decreasing (again, this only applies without weight decay). + """ + if step is None: + step = self._step + t = step / self._warm_step # floating point division.. t is the normalized step. + base_rate = self._max_lrate * (t if t <= 1.0 else 1.0) + epoch_rate = self._decay_per_epoch ** max(0, self._epoch + 1 - self._first_decrease_epoch) + return base_rate * epoch_rate + + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "_epoch": self._epoch, + } + + def load_state_dict(self, state_dict): + """Load state_dict. 
+        for key, value in state_dict.items():
+            if key == "optimizer":
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            elif key == '_step':
+                self._step = value
+            elif key == '_epoch':
+                self._epoch = value
+
+
 class TestModel(torch.nn.Module):
     """Class for testing the Madam optimizer"""
     def __init__(self):
diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py
index d70d06674..dd35a2d77 100755
--- a/egs/librispeech/ASR/conformer_lm/train.py
+++ b/egs/librispeech/ASR/conformer_lm/train.py
@@ -35,7 +35,7 @@ from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from madam import Foam
+from madam import Gloam
 
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -133,7 +133,8 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             # exp_3, vs. exp_2, is using 5e-04 not 2e-04 as max learning rate.
-            "exp_dir": Path("conformer_lm/exp_3"),
+            # exp_4, vs. exp_3, is using the Gloam optimizer, with first_decrease_epoch=2 and decay_per_epoch=0.85.
+            "exp_dir": Path("conformer_lm/exp_4"),
             "lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
             "num_tokens": 5000,
             "blank_sym": 0,
@@ -520,9 +521,13 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = Foam(
+    # Caution: don't forget to call optimizer.set_epoch() with Gloam!
+    # Don't remove this warning!
+    optimizer = Gloam(
         model.parameters(),
-        max_lrate=params.max_lrate
+        max_lrate=params.max_lrate,
+        first_decrease_epoch=2,
+        decay_per_epoch=0.85
     )
 
     if checkpoints:
@@ -556,6 +561,8 @@ def run(rank, world_size, args):
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_sampler.set_epoch(epoch)
+        optimizer.set_epoch(epoch)  # Caution: this is specific to the
+                                    # Gloam optimizer.
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
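
For reference, below is a minimal usage sketch of the Gloam optimizer added above. It assumes it is run from egs/librispeech/ASR/conformer_lm so that madam.py is importable; the toy model, batch shape, and step counts are hypothetical values chosen only for illustration (the real training loop is in train.py).

import torch

from madam import Gloam  # madam.py from this recipe directory


model = torch.nn.Linear(8, 4)  # hypothetical toy model, for illustration only
optimizer = Gloam(
    model.parameters(),
    max_lrate=5.0e-04,
    warm_step=25000,
    first_decrease_epoch=2,
    decay_per_epoch=0.85,
)

for epoch in range(4):
    # Per the Gloam docstring: set_epoch() must be called every epoch,
    # otherwise the per-epoch decay never takes effect.
    optimizer.set_epoch(epoch)
    for _ in range(100):  # stand-in for iterating over the real LM dataset
        x = torch.randn(32, 8)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # train.py reads the current rate the same way (optimizer._rate):
    print(f"epoch {epoch}: lr = {optimizer._rate:.3e}")

With first_decrease_epoch=2 and decay_per_epoch=0.85 (the values train.py now passes), rate() multiplies the post-warmup learning rate by 0.85 ** max(0, epoch - 1): epochs 0 and 1 run at the full max_lrate, and each later epoch decays it by a further factor of 0.85.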