check some files

This commit is contained in:
luomingshuang 2022-04-11 20:41:32 +08:00
parent fecceee216
commit 16c6e0207b


@@ -15,11 +15,9 @@
# limitations under the License.

from typing import List, Optional, Union

import torch
from torch.optim import Optimizer
@@ -59,24 +57,41 @@ class Eve(Optimizer):
    https://openreview.net/forum?id=ryQu7f-RZ

    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.98),
        eps=1e-8,
        weight_decay=1e-3,
        target_rms=0.1,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1])
            )
        if not 0 <= weight_decay <= 0.1:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )
        if not 0 < target_rms <= 10.0:
            raise ValueError("Invalid target_rms value: {}".format(target_rms))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            target_rms=target_rms,
        )
        super(Eve, self).__init__(params, defaults)

    def __setstate__(self, state):
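Not part of the commit, only an illustration: a minimal sketch of how the Eve constructor above is typically driven, assuming this file is importable as optim (the module name and the toy model are placeholders).

import torch

from optim import Eve  # assumed module name for the file in this diff

model = torch.nn.Linear(100, 100)
optimizer = Eve(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.98),
    weight_decay=1e-3,
    target_rms=0.1,
)

x = torch.randn(32, 100)
loss = model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()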
@@ -96,83 +111,98 @@ class Eve(Optimizer):
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdamW does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
                    group["eps"]
                )

                step_size = group["lr"] / bias_correction1
                target_rms = group["target_rms"]
                weight_decay = group["weight_decay"]

                if p.numel() > 1:
                    # avoid applying this weight-decay on "scaling factors"
                    # (which are scalar).
                    is_above_target_rms = p.norm() > (
                        target_rms * (p.numel() ** 0.5)
                    )
                    p.mul_(1 - (weight_decay * is_above_target_rms))
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss


class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "{} is not an Optimizer".format(type(optimizer).__name__)
            )
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])

        self.base_lrs = [
            group["initial_lr"] for group in optimizer.param_groups
        ]

        self.epoch = 0
        self.batch = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {
            "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
        }

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.
@@ -184,8 +214,7 @@ class LRScheduler(object):
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler. Will be a list of float."""
        return self._last_lr

    def get_lr(self):
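Not part of the diff: a sketch of how the state_dict() / load_state_dict() pair above can be used for checkpointing. The module name, file name, and scheduler settings are illustrative; Eden is the subclass defined further down in this file.

import torch

from optim import Eden, Eve  # assumed module name for the file in this diff

model = torch.nn.Linear(100, 100)
optimizer = Eve(model.parameters(), lr=1e-3)
scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)

torch.save(
    {
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),  # {"base_lrs": ..., "epoch": ..., "batch": ...}
    },
    "checkpoint.pt",
)

# Later, after re-creating the same model, optimizer and scheduler:
checkpoint = torch.load("checkpoint.pt")
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])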
@@ -194,7 +223,6 @@ class LRScheduler(object):
        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it. If `batch` is specified, it
        # must be the batch index from the start of training, i.e. summed over
@@ -217,24 +245,23 @@ class LRScheduler(object):
            self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            print(
                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )


class Eden(LRScheduler):
@@ -254,18 +281,27 @@ class Eden(LRScheduler):
        20 to 40 epochs, but may need smaller number if dataset is huge
        and you will do few epochs.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs

    def get_lr(self):
        factor = (
            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
        ) ** -0.25 * (
            ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
            ** -0.25
        )
        return [x * factor for x in self.base_lrs]


def _test_eden():
    m = torch.nn.Linear(100, 100)
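A quick numeric check of the get_lr() factor above (not in the commit; the lr_batches / lr_epochs values are arbitrary): the factor is exactly 1.0 at batch=0, epoch=0 and decays smoothly once either counter grows past its corresponding constant.

def eden_factor(batch, epoch, lr_batches=30, lr_epochs=2):
    # Same expression as Eden.get_lr() above, with explicit arguments.
    return ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25 * (
        (epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2
    ) ** -0.25

print(eden_factor(0, 0))     # 1.0: full base_lr at the start
print(eden_factor(30, 0))    # 2 ** -0.25, about 0.84, once batch reaches lr_batches
print(eden_factor(300, 10))  # about 0.14 late in training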
@@ -290,5 +326,6 @@ def _test_eden():
    print("last lr = ", scheduler.get_last_lr())
    print("state dict = ", scheduler.state_dict())


if __name__ == "__main__":
    _test_eden()
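Finally, a sketch (not from the commit) of how Eve and Eden are typically combined in a training loop, in the spirit of _test_eden() above. step_epoch is assumed to be the epoch-level counterpart of step_batch (the method whose body increments self.epoch in the hunk further up); the model, data and loop sizes are placeholders.

import torch

from optim import Eden, Eve  # assumed module name for the file in this diff

model = torch.nn.Linear(100, 100)
optimizer = Eve(model.parameters(), lr=3e-3)
scheduler = Eden(optimizer, lr_batches=30, lr_epochs=2)

for epoch in range(3):
    for _ in range(100):
        x = torch.randn(32, 100)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step_batch()  # advance the batch counter and refresh group lrs
    scheduler.step_epoch()  # assumed name; advances the epoch counter

print("last lr = ", scheduler.get_last_lr())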