From ccf7bdec230edb9a833770368b92b74503ae1125 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 28 Aug 2021 21:51:54 +0800
Subject: [PATCH] Add Foam optimizer; I used this from epoch 3.

---
 egs/librispeech/ASR/conformer_lm/madam.py     | 186 +++++++++++++++++-
 .../ASR/conformer_lm/test_dataset.py          |  39 +++-
 egs/librispeech/ASR/conformer_lm/train.py     |  13 +-
 3 files changed, 217 insertions(+), 21 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_lm/madam.py b/egs/librispeech/ASR/conformer_lm/madam.py
index aa605c30b..07266a63b 100644
--- a/egs/librispeech/ASR/conformer_lm/madam.py
+++ b/egs/librispeech/ASR/conformer_lm/madam.py
@@ -811,6 +811,140 @@ class Moam(object):
             setattr(self, key, value)
 
 
+class Foam(object):
+    """
+    Implements Foam optimizer.  This is a modified version of the Noam optimizer
+    which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf,
+    but changed to use Madam (see above) instead of Adam as the base optimizer, and
+    to change the learning rate schedule and how it is specified.
+
+    This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
+
+        warm_step: number of warmup steps before the learning rate starts to decrease
+              (it increases until this point).
+        max_lrate: The learning rate at its maximum, on step `warm_step`
+        knee_factor: The multiple of `warm_step` after which the learning rate will
+              start to decrease more like 1/x.  The learning rate increases linearly
+              from step 0 to `warm_step`, then decreases approximately as 1/sqrt(x)
+              from `warm_step` to `warm_step * knee_factor`, then decreases
+              approximately as 1/x from `warm_step * knee_factor` onwards.
+
+        min_target_rms: this is a parameter of the Madam optimizer; it represents a floor
+              on the "target root-mean-square value" that is used when the initialization
+              of a tensor is zero or below this value.  It may be worth optimizing.
+              Don't worry about tensors with fewer than 2 dimensions when setting this;
+              these are not subject to our l2 formula.
+        limit_grad_factor: Another parameter of Madam; you can set this to a finite
+              value, e.g. 2.0, to activate a mechanism that limits the norms of
+              larger-than-usual gradients.  This seems to cause a slowdown, likely due
+              to GPU->CPU transfers, and it is disabled by setting it to infinity.
+        l2_period: a mechanism to improve optimization speed, by only applying the l2
+              regularization (which is a fairly complicated formula) every this-many
+              minibatches.  E.g. can set it to 2 or 4.
+    """
+
+    def __init__(self,
+                 params,
+                 max_lrate: float = 5.0e-04,
+                 warm_step: int = 25000,
+                 knee_factor: float = 8.0,
+                 min_target_rms: float = 0.05,
+                 limit_grad_factor: float = float('inf'),
+                 l2_period: int = 1) -> None:
+        """Construct a Foam object."""
+        self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9,
+                               min_target_rms=min_target_rms,
+                               limit_grad_factor=limit_grad_factor,
+                               l2_period=l2_period)
+        self._step = 0
+
+        self._max_lrate = max_lrate
+        self._warm_step = warm_step
+        self._knee_factor = knee_factor
+        self._rate = 0
+
+    @property
+    def param_groups(self):
+        """Return param_groups."""
+        return self.optimizer.param_groups
+
+    def step(self):
+        """Update parameters and rate."""
+        self._step += 1
+        rate = self.rate()
+        for p in self.optimizer.param_groups:
+            p["lr"] = rate
+        self._rate = rate
+        self.optimizer.step()
+
+    def rate(self, step=None):
+        """
+        Suppose the step of optimization is 's', i.e. with s = 0, 1, 2...
+        We define 't = s / warm_step', i.e. t is the step s, normalized so that it
+        is 1.0 at warm_step.  Our formula for the learning rate as a function of
+        t is:
+             rate = max_lrate * (t <= 1.0 ? t :
+                                 sqrt((2 + alpha) / (1 + t + alpha t^2)))
+        where alpha is chosen so that the 't' and 'alpha t^2' terms are identical
+        at t == knee_factor (this means alpha = 1.0/knee_factor).  So the
+        learning rate increases linearly from t=0 to t=1, and decreases
+        after that.  You can see
+        that sqrt((2 + alpha) / (1 + t + alpha t^2)) is 1.0 when t == 1,
+        which is why the line and the curve meet at that point.
+
+        In the denominator of that ratio, the "t" term makes it decrease a
+        bit like 1/sqrt(t) for 1 <= t <= knee_factor; the "alpha t^2" term
+        makes it decrease a bit like 1/t for t > knee_factor; and the "1"
+        term makes it decrease a bit slower than 1/sqrt(t) when t is quite
+        close to 1.0 (so we linger a little, near the maximum learning rate).
+
+        This learning rate schedule ultimately decreases more aggressively
+        than Noam, i.e. as 1 / t instead of 1 / sqrt(t).  The reason we
+        feel this will work better in conjunction with Madam is that Madam
+        keeps the norms of the parameters approximately constant throughout
+        training, whereas with Noam, if there is no weight decay, these
+        norms tend to increase as training progresses (although rather
+        unevenly across different parameter tensors).
+        As the norms of the parameters increase, the relative changes
+        in parameters get smaller (the step sizes don't change because
+        Adam normalizes the gradient magnitudes; they'd get smaller otherwise).
+        So Noam doesn't have to decrease the learning rate too aggressively,
+        because even with a fixed learning rate the effective learning rate
+        would be decreasing (again, this only applies without weight decay).
+        """
+        if step is None:
+            step = self._step
+        t = step / self._warm_step  # floating-point division; t is the normalized step.
+        alpha = 1.0 / self._knee_factor
+        return self._max_lrate * (t if t <= 1.0 else
+                                  ((2 + alpha) / (1 + t + alpha * t * t)) ** 0.5)
+
+    def zero_grad(self):
+        """Reset gradient."""
+        self.optimizer.zero_grad()
+
+    def state_dict(self):
+        """Return state_dict."""
+        return {
+            "_step": self._step,
+        }
+
+    def load_state_dict(self, state_dict):
+        """Load state_dict.  This is compatible with reading a Moam state_dict."""
+        for key, value in state_dict.items():
+            if key == "optimizer":
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            elif key == '_step':
+                self._step = value
+
+
 class TestModel(torch.nn.Module):
     """Class for testing the Madam optimizer"""
@@ -844,9 +978,9 @@ def test_madam():
     inf_grad_max_count = 200
     if torch.cuda.is_available():
         devices_and_l2 = [(torch.device('cuda'), True),
-                          (torch.device('cuda'), False)]
-                          #(torch.device('cpu'), True),
-                          #(torch.device('cpu'), False)]
+                          (torch.device('cuda'), False),
+                          (torch.device('cpu'), True),
+                          (torch.device('cpu'), False)]
     else:
         devices_and_l2 = [(torch.device('cpu'), True),
                           (torch.device('cpu'), False)]
@@ -922,6 +1056,48 @@ def test_moam():
     print("")
 
 
+def test_foam():
+    print("Testing Foam optimizer")
+    model = TestModel()
+    # min_target_rms=0.01 is for testing, so the target equals the initial RMS
+    # and we can more easily tell whether our update has the desired effect.
+    optimizer = Foam(model.parameters(),
+                     max_lrate=1.0e-03, warm_step=300,
+                     min_target_rms=0.01,
+                     limit_grad_factor=4.0)
+
+    def get_elems_rms(x: Tensor) -> Tensor:
+        return ((x ** 2).sum() / x.numel()).sqrt().item()
+
+    for i in range(1000):
+        if i % 100 == 0:
+            rms_values = (get_elems_rms(model.first_layers[0].weight),
+                          get_elems_rms(model.first_layers[2].weight),
+                          get_elems_rms(model.conv1.weight),
+                          get_elems_rms(model.conv2.weight))
+            print(f"Iter {i} (Foam): stddevs = {rms_values}")
+        B = 4
+        T = 20
+        x = torch.randn(B, T, 100)
+        y = model(x)
+        yderiv = torch.randn_like(y)
+        if i % 190 <= 3 and i > 0:
+            yderiv *= 100.0
+        if i % 550 == 0 and i > 0:
+            yderiv *= float('inf')
+
+        y.backward(gradient=yderiv)
+        optimizer.step()
+        model.zero_grad()
+    print("")
+
+    state_dict = optimizer.state_dict()
+    step = optimizer._step
+    optimizer._step = 0
+    optimizer.load_state_dict(state_dict)
+    assert optimizer._step == step
+
 
 def test_to_device():
     if not torch.cuda.is_available():
@@ -951,8 +1127,10 @@ def main():
     #test_to_device()
     random.seed(0)
     torch.random.manual_seed(0)
+    test_foam()
+    test_moam()
     test_madam()
-    #test_moam()
+
 
 
 if __name__ == '__main__':
diff --git a/egs/librispeech/ASR/conformer_lm/test_dataset.py b/egs/librispeech/ASR/conformer_lm/test_dataset.py
index ed38ed11a..b82da7899 100644
--- a/egs/librispeech/ASR/conformer_lm/test_dataset.py
+++ b/egs/librispeech/ASR/conformer_lm/test_dataset.py
@@ -1,13 +1,34 @@
-import dataset
+import k2
 import torch
+import _k2
+import dataset
+import os
+from torch import multiprocessing as mp
+import torch.distributed as dist
 
+def local_collate_fn(sentences):
+    return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)
 
-train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
-sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
-a = iter(sampler)
-print(str(next(a)))
+x = _k2.RaggedInt('[[1]]')  # make sure library initialized?
 
-collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True))
-train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
-x = iter(train_dl)
-print(str(next(x)))
+if __name__ == '__main__':
+
+    #mp.set_start_method('spawn')
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12344"
+
+    dist.init_process_group(backend="nccl", group_name="main",
+                            rank=0, world_size=1)
+
+    train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
+    sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
+    print("len(sampler) = ", len(sampler))
+
+    a = iter(sampler)
+    print(str(next(a)))
+
+    train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler,
+                                           collate_fn=local_collate_fn,
+                                           num_workers=2)
+    x = iter(train_dl)
+    print(str(next(x)))
diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py
index 0b7e49db5..5ca267147 100755
--- a/egs/librispeech/ASR/conformer_lm/train.py
+++ b/egs/librispeech/ASR/conformer_lm/train.py
@@ -35,7 +35,7 @@ from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from madam import Moam
+from madam import Foam
 
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -138,7 +138,7 @@ def get_params() -> AttributeDict:
             "blank_sym": 0,
             "bos_sym": 1,
             "eos_sym": 1,
-            "start_epoch": 0,
+            "start_epoch": 3,
             "num_epochs": 20,
             "num_valid_batches": 200,
             "symbols_per_batch": 5000,
@@ -155,8 +155,7 @@ def get_params() -> AttributeDict:
             "attention_dim": 512,
             "nhead": 8,
             "num_decoder_layers": 6,
-            "lr_factor": 2.0,
-            "warm_step": 20000,
+            "max_lrate": 5.0e-04
         }
     )
 
@@ -520,11 +519,9 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
 
-    optimizer = Moam(
+    optimizer = Foam(
         model.parameters(),
-        model_size=params.attention_dim,
-        factor=params.lr_factor,
-        warm_step=params.warm_step,
+        max_lrate=params.max_lrate
     )
 
     if checkpoints:
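
Note on the schedule: the formula in Foam.rate() above can be sanity-checked with a small
standalone sketch.  The helper name foam_rate and the example step values below are
illustrative only (not part of the patch); the default constants mirror Foam's
(max_lrate=5.0e-04, warm_step=25000, knee_factor=8.0).

def foam_rate(step: int,
              max_lrate: float = 5.0e-04,
              warm_step: int = 25000,
              knee_factor: float = 8.0) -> float:
    # Same formula as Foam.rate(): linear warm-up to max_lrate at warm_step,
    # then roughly 1/sqrt(step) until warm_step * knee_factor, then roughly 1/step.
    t = step / warm_step           # normalized step; t == 1.0 at warm_step
    alpha = 1.0 / knee_factor      # makes the 't' and 'alpha * t^2' terms equal at t == knee_factor
    if t <= 1.0:
        return max_lrate * t
    return max_lrate * ((2 + alpha) / (1 + t + alpha * t * t)) ** 0.5

for step in (0, 12500, 25000, 50000, 200000, 800000):
    print(step, foam_rate(step))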
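
Similarly, a minimal usage sketch, assuming the madam.py touched in this patch is importable
and that its Madam base optimizer handles ordinary nn.Linear parameters as it does in
test_foam() above; the toy model and the number of steps here are made up.

import torch
from madam import Foam  # provided by the madam.py in this patch

model = torch.nn.Linear(100, 100)
optimizer = Foam(model.parameters(), max_lrate=5.0e-04, warm_step=25000)

for _ in range(10):
    loss = model(torch.randn(4, 100)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()   # advances the step count and applies the scheduled learning rate

# Foam.state_dict() stores only "_step"; load_state_dict() restores it and also
# accepts an older Moam-style dict containing an "optimizer" entry.
checkpoint = optimizer.state_dict()
optimizer.load_state_dict(checkpoint)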