# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class Noam(object):
    """
    Implements the Noam optimizer: Adam with the warmup-then-decay learning
    rate schedule proposed in "Attention Is All You Need",
    https://arxiv.org/pdf/1706.03762.pdf

    Modified from
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa

    Args:
      params:
        Iterable of parameters to optimize or dicts defining parameter groups.
      model_size:
        Attention dimension of the transformer model.
      factor:
        Learning rate factor.
      warm_step:
        Number of warmup steps.
    """

    def __init__(
        self,
        params,
        model_size: int = 256,
        factor: float = 10.0,
        warm_step: int = 25000,
        weight_decay=0,
    ) -> None:
        """Construct a Noam object."""
        self.optimizer = torch.optim.Adam(
            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
        )
        self._step = 0
        self.warmup = warm_step
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update parameters and learning rate."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Compute the Noam learning rate for the given step.

        The rate increases linearly for the first `warmup` steps and then
        decays proportionally to step ** -0.5.
        """
        if step is None:
            step = self._step
        return (
            self.factor
            * self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )

    def zero_grad(self):
        """Reset gradients."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return state_dict."""
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Load state_dict."""
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(state_dict["optimizer"])
            else:
                setattr(self, key, value)
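

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training recipe: wrap a toy
    # model's parameters in Noam and run a few optimization steps. The model,
    # loss, and hyper-parameter values below are illustrative placeholders,
    # not the settings used by the recipe.
    model = torch.nn.Linear(16, 4)
    optimizer = Noam(model.parameters(), model_size=16, factor=1.0, warm_step=10)

    for _ in range(20):
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        # step() recomputes the Noam learning rate, writes it into every
        # param group, and then delegates to the underlying Adam step.
        optimizer.step()

    # The learning rate grows linearly for `warm_step` steps, then decays
    # proportionally to step ** -0.5.
    print("learning rate after 20 steps:", optimizer.rate())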