mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
Implement Cain with scdaling incorporated;
Removing scaling from ScaledLinear, ScaledConv1d, etc.
This commit is contained in:
parent
8fd9e64fdf
commit
abe5abb688
@ -22,6 +22,7 @@ from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Embedding as ScaledEmbedding
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
@ -154,10 +155,8 @@ class BasicNorm(torch.nn.Module):
|
||||
|
||||
class ScaledLinear(nn.Linear):
|
||||
"""
|
||||
A modified version of nn.Linear where the parameters are scaled before
|
||||
use, via:
|
||||
weight = self.weight * self.weight_scale.exp()
|
||||
bias = self.bias * self.bias_scale.exp()
|
||||
A modified version of nn.Linear that gives an easy way to set the
|
||||
default initial parameter scale.
|
||||
|
||||
Args:
|
||||
Accepts the standard args and kwargs that nn.Linear accepts
|
||||
@ -168,59 +167,26 @@ class ScaledLinear(nn.Linear):
|
||||
(affects the initialization of weight_scale and bias_scale).
|
||||
Another option, if you want to do something like this, is
|
||||
to re-initialize the parameters.
|
||||
initial_speed: this affects how fast the parameter will
|
||||
learn near the start of training; you can set it to a
|
||||
value less than one if you suspect that a module
|
||||
is contributing to instability near the start of training.
|
||||
Nnote: regardless of the use of this option, it's best to
|
||||
use schedulers like Noam that have a warm-up period.
|
||||
Alternatively you can set it to more than 1 if you want it to
|
||||
initially train faster. Must be greater than 0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
initial_scale: float = 1.0,
|
||||
initial_speed: float = 1.0,
|
||||
**kwargs
|
||||
):
|
||||
super(ScaledLinear, self).__init__(*args, **kwargs)
|
||||
initial_scale = torch.tensor(initial_scale).log()
|
||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
if self.bias is not None:
|
||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
else:
|
||||
self.register_parameter("bias_scale", None)
|
||||
|
||||
self._reset_parameters(
|
||||
initial_speed
|
||||
) # Overrides the reset_parameters in nn.Linear
|
||||
|
||||
def _reset_parameters(self, initial_speed: float):
|
||||
std = 0.1 / initial_speed
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0.0)
|
||||
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
||||
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
|
||||
with torch.no_grad():
|
||||
self.weight_scale += torch.tensor(scale / std).log()
|
||||
self.weight[:] *= initial_scale
|
||||
if self.bias is not None:
|
||||
self.bias[:] *= initial_scale
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
def get_weight(self): # not needed any more but kept for back compatibility
|
||||
return self.weight
|
||||
|
||||
def get_bias(self):
|
||||
if self.bias is None or self.bias_scale is None:
|
||||
return None
|
||||
return self.bias
|
||||
|
||||
return self.bias * self.bias_scale.exp()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return torch.nn.functional.linear(
|
||||
input, self.get_weight(), self.get_bias()
|
||||
)
|
||||
|
||||
|
||||
class ScaledConv1d(nn.Conv1d):
|
||||
@ -229,66 +195,20 @@ class ScaledConv1d(nn.Conv1d):
|
||||
self,
|
||||
*args,
|
||||
initial_scale: float = 1.0,
|
||||
initial_speed: float = 1.0,
|
||||
**kwargs
|
||||
):
|
||||
super(ScaledConv1d, self).__init__(*args, **kwargs)
|
||||
initial_scale = torch.tensor(initial_scale).log()
|
||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
if self.bias is not None:
|
||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
else:
|
||||
self.register_parameter("bias_scale", None)
|
||||
self._reset_parameters(
|
||||
initial_speed
|
||||
) # Overrides the reset_parameters in base class
|
||||
|
||||
def _reset_parameters(self, initial_speed: float):
|
||||
std = 0.1 / initial_speed
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0.0)
|
||||
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
||||
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
|
||||
with torch.no_grad():
|
||||
self.weight_scale += torch.tensor(scale / std).log()
|
||||
self.weight[:] *= initial_scale
|
||||
if self.bias is not None:
|
||||
self.bias[:] *= initial_scale
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
bias = self.bias
|
||||
bias_scale = self.bias_scale
|
||||
if bias is None or bias_scale is None:
|
||||
return None
|
||||
return bias * bias_scale.exp()
|
||||
def get_weight(self): # TODO: delete
|
||||
return self.weight
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
F = torch.nn.functional
|
||||
if self.padding_mode != "zeros":
|
||||
return F.conv1d(
|
||||
F.pad(
|
||||
input,
|
||||
self._reversed_padding_repeated_twice,
|
||||
mode=self.padding_mode,
|
||||
),
|
||||
self.get_weight(),
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
(0,),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
return F.conv1d(
|
||||
input,
|
||||
self.get_weight(),
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
def get_bias(self): # TODO: delete
|
||||
return self.bias
|
||||
|
||||
|
||||
class ScaledConv2d(nn.Conv2d):
|
||||
@ -297,70 +217,20 @@ class ScaledConv2d(nn.Conv2d):
|
||||
self,
|
||||
*args,
|
||||
initial_scale: float = 1.0,
|
||||
initial_speed: float = 1.0,
|
||||
**kwargs
|
||||
):
|
||||
super(ScaledConv2d, self).__init__(*args, **kwargs)
|
||||
initial_scale = torch.tensor(initial_scale).log()
|
||||
self.weight_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
if self.bias is not None:
|
||||
self.bias_scale = nn.Parameter(initial_scale.clone().detach())
|
||||
else:
|
||||
self.register_parameter("bias_scale", None)
|
||||
self._reset_parameters(
|
||||
initial_speed
|
||||
) # Overrides the reset_parameters in base class
|
||||
|
||||
def _reset_parameters(self, initial_speed: float):
|
||||
std = 0.1 / initial_speed
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
nn.init.constant_(self.bias, 0.0)
|
||||
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
|
||||
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
|
||||
with torch.no_grad():
|
||||
self.weight_scale += torch.tensor(scale / std).log()
|
||||
self.weight[:] *= initial_scale
|
||||
if self.bias is not None:
|
||||
self.bias[:] *= initial_scale
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
return self.weight
|
||||
|
||||
def get_bias(self):
|
||||
# see https://github.com/pytorch/pytorch/issues/24135
|
||||
bias = self.bias
|
||||
bias_scale = self.bias_scale
|
||||
if bias is None or bias_scale is None:
|
||||
return None
|
||||
return bias * bias_scale.exp()
|
||||
return self.bias
|
||||
|
||||
def _conv_forward(self, input, weight):
|
||||
F = torch.nn.functional
|
||||
if self.padding_mode != "zeros":
|
||||
return F.conv2d(
|
||||
F.pad(
|
||||
input,
|
||||
self._reversed_padding_repeated_twice,
|
||||
mode=self.padding_mode,
|
||||
),
|
||||
weight,
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
(0, 0),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
return F.conv2d(
|
||||
input,
|
||||
weight,
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return self._conv_forward(input, self.get_weight())
|
||||
|
||||
|
||||
class ActivationBalancer(torch.nn.Module):
|
||||
@ -464,179 +334,6 @@ class DoubleSwish(torch.nn.Module):
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
|
||||
class ScaledEmbedding(nn.Module):
|
||||
r"""This is a modified version of nn.Embedding that introduces a learnable scale
|
||||
on the parameters. Note: due to how we initialize it, it's best used with
|
||||
schedulers like Noam that have a warmup period.
|
||||
|
||||
It is a simple lookup table that stores embeddings of a fixed dictionary and size.
|
||||
|
||||
This module is often used to store word embeddings and retrieve them using indices.
|
||||
The input to the module is a list of indices, and the output is the corresponding
|
||||
word embeddings.
|
||||
|
||||
Args:
|
||||
num_embeddings (int): size of the dictionary of embeddings
|
||||
embedding_dim (int): the size of each embedding vector
|
||||
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
|
||||
(initialized to zeros) whenever it encounters the index.
|
||||
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
|
||||
is renormalized to have norm :attr:`max_norm`.
|
||||
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
|
||||
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
|
||||
the words in the mini-batch. Default ``False``.
|
||||
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
|
||||
See Notes for more details regarding sparse gradients.
|
||||
|
||||
initial_speed (float, optional): This affects how fast the parameter will
|
||||
learn near the start of training; you can set it to a value less than
|
||||
one if you suspect that a module is contributing to instability near
|
||||
the start of training. Nnote: regardless of the use of this option,
|
||||
it's best to use schedulers like Noam that have a warm-up period.
|
||||
Alternatively you can set it to more than 1 if you want it to
|
||||
initially train faster. Must be greater than 0.
|
||||
|
||||
|
||||
Attributes:
|
||||
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
|
||||
initialized from :math:`\mathcal{N}(0, 1)`
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
|
||||
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
|
||||
|
||||
.. note::
|
||||
Keep in mind that only a limited number of optimizers support
|
||||
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
|
||||
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
|
||||
|
||||
.. note::
|
||||
With :attr:`padding_idx` set, the embedding vector at
|
||||
:attr:`padding_idx` is initialized to all zeros. However, note that this
|
||||
vector can be modified afterwards, e.g., using a customized
|
||||
initialization method, and thus changing the vector used to pad the
|
||||
output. The gradient for this vector from :class:`~torch.nn.Embedding`
|
||||
is always zero.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> # an Embedding module containing 10 tensors of size 3
|
||||
>>> embedding = nn.Embedding(10, 3)
|
||||
>>> # a batch of 2 samples of 4 indices each
|
||||
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
|
||||
>>> embedding(input)
|
||||
tensor([[[-0.0251, -1.6902, 0.7172],
|
||||
[-0.6431, 0.0748, 0.6969],
|
||||
[ 1.4970, 1.3448, -0.9685],
|
||||
[-0.3677, -2.7265, -0.1685]],
|
||||
|
||||
[[ 1.4970, 1.3448, -0.9685],
|
||||
[ 0.4362, -0.4004, 0.9400],
|
||||
[-0.6431, 0.0748, 0.6969],
|
||||
[ 0.9124, -2.3616, 1.1151]]])
|
||||
|
||||
|
||||
>>> # example with padding_idx
|
||||
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
|
||||
>>> input = torch.LongTensor([[0,2,0,5]])
|
||||
>>> embedding(input)
|
||||
tensor([[[ 0.0000, 0.0000, 0.0000],
|
||||
[ 0.1535, -2.0309, 0.9315],
|
||||
[ 0.0000, 0.0000, 0.0000],
|
||||
[-0.1655, 0.9897, 0.0635]]])
|
||||
|
||||
"""
|
||||
__constants__ = [
|
||||
"num_embeddings",
|
||||
"embedding_dim",
|
||||
"padding_idx",
|
||||
"scale_grad_by_freq",
|
||||
"sparse",
|
||||
]
|
||||
|
||||
num_embeddings: int
|
||||
embedding_dim: int
|
||||
padding_idx: int
|
||||
scale_grad_by_freq: bool
|
||||
weight: Tensor
|
||||
sparse: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: Optional[int] = None,
|
||||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
initial_speed: float = 1.0,
|
||||
) -> None:
|
||||
super(ScaledEmbedding, self).__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if padding_idx is not None:
|
||||
if padding_idx > 0:
|
||||
assert (
|
||||
padding_idx < self.num_embeddings
|
||||
), "Padding_idx must be within num_embeddings"
|
||||
elif padding_idx < 0:
|
||||
assert (
|
||||
padding_idx >= -self.num_embeddings
|
||||
), "Padding_idx must be within num_embeddings"
|
||||
padding_idx = self.num_embeddings + padding_idx
|
||||
self.padding_idx = padding_idx
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
|
||||
self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters()
|
||||
self.sparse = sparse
|
||||
|
||||
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
|
||||
self.reset_parameters(initial_speed)
|
||||
|
||||
def reset_parameters(self, initial_speed: float = 1.0) -> None:
|
||||
std = 0.1 / initial_speed
|
||||
nn.init.normal_(self.weight, std=std)
|
||||
nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())
|
||||
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
F = torch.nn.functional
|
||||
scale = self.scale.exp()
|
||||
if input.numel() < self.num_embeddings:
|
||||
return (
|
||||
F.embedding(
|
||||
input,
|
||||
self.weight,
|
||||
self.padding_idx,
|
||||
None,
|
||||
2.0, # None, 2.0 relate to normalization
|
||||
self.scale_grad_by_freq,
|
||||
self.sparse,
|
||||
)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
return F.embedding(
|
||||
input,
|
||||
self.weight * scale,
|
||||
self.padding_idx,
|
||||
None,
|
||||
2.0, # None, 2.0 relates to normalization
|
||||
self.scale_grad_by_freq,
|
||||
self.sparse,
|
||||
)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = "{num_embeddings}, {embedding_dim}, scale={scale}"
|
||||
if self.padding_idx is not None:
|
||||
s += ", padding_idx={padding_idx}"
|
||||
if self.scale_grad_by_freq is not False:
|
||||
s += ", scale_grad_by_freq={scale_grad_by_freq}"
|
||||
if self.sparse is not False:
|
||||
s += ", sparse=True"
|
||||
return s.format(**self.__dict__)
|
||||
|
||||
|
||||
def _test_activation_balancer_sign():
|
||||
|
@ -205,9 +205,16 @@ class Abel(Optimizer):
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-08).
|
||||
target_rms (float, optional): target root-mean-square value of
|
||||
x_norm in factorization (conceptually). Actually this now just becomes
|
||||
a factor in the learning rate, we may remove it at some point.
|
||||
|
||||
parameters, when we normalize them (only conceptually!)..
|
||||
actually this now just becomes
|
||||
a factor in the learning rate for non-scalar parameters.
|
||||
rms_eps (float, optional): epsilon value such that when parameter
|
||||
rms value goes below this, we stop learning slower for that parameter,
|
||||
i.e. we treat the rms as if it were rms_eps.
|
||||
rms_max (float, optional): maximum root-mean-square value for
|
||||
non-scalar parameters, such that we don't let the parameter
|
||||
values exceed this; we'll shrink any parameter matrices that become
|
||||
larger than this.
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
@ -438,6 +445,9 @@ class Cain(Optimizer):
|
||||
co-ordinates every so often, just in the optimizer, to separate big/small
|
||||
directions.
|
||||
|
||||
This version of Cain absorbs the learned scales of the parameter matrices into
|
||||
the optimizer.
|
||||
|
||||
Eve is a modified version of AdamW with a special
|
||||
way of setting the weight-decay / shrinkage-factor, which is designed to make the
|
||||
rms of the parameters approach a particular target_rms (default: 0.1). This is
|
||||
@ -456,12 +466,17 @@ class Cain(Optimizer):
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
|
||||
this value means that the weight would decay significantly after
|
||||
about 3k minibatches. Is not multiplied by learning rate, but
|
||||
is conditional on RMS-value of parameter being > target_rms.
|
||||
target_rms (float, optional): target root-mean-square value of
|
||||
parameters, if they fall below this we will stop applying weight decay.
|
||||
rms_eps (float, optional): epsilon value such that when parameter
|
||||
rms value goes below this, we stop learning slower for that parameter,
|
||||
i.e. we treat the rms as if it were rms_eps.
|
||||
rms_max (float, optional): maximum root-mean-square value for
|
||||
non-scalar parameters, such that we don't let the parameter
|
||||
values exceed this; we'll shrink any parameter matrices that become
|
||||
larger than this.
|
||||
|
||||
|
||||
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
@ -478,8 +493,9 @@ class Cain(Optimizer):
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.98),
|
||||
eps=1e-8,
|
||||
weight_decay=1e-3,
|
||||
target_rms=0.1,
|
||||
rms_eps=1.0e-05,
|
||||
rms_max=10.0,
|
||||
):
|
||||
|
||||
if not 0.0 <= lr:
|
||||
@ -494,18 +510,20 @@ class Cain(Optimizer):
|
||||
raise ValueError(
|
||||
"Invalid beta parameter at index 1: {}".format(betas[1])
|
||||
)
|
||||
if not 0 <= weight_decay <= 0.1:
|
||||
raise ValueError(
|
||||
"Invalid weight_decay value: {}".format(weight_decay)
|
||||
)
|
||||
if not 0 < target_rms <= 10.0:
|
||||
raise ValueError("Invalid target_rms value: {}".format(target_rms))
|
||||
if not 0.1 < rms_max <= 1000.0:
|
||||
raise ValueError("Invalid rms_max value: {}".format(rms_max))
|
||||
if not 0.0 < rms_eps < 0.1:
|
||||
raise ValueError("Invalid rms_eps value: {}".format(rms_eps))
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
target_rms=target_rms,
|
||||
rms_eps=rms_eps,
|
||||
rms_max=rms_max,
|
||||
)
|
||||
super(Cain, self).__init__(params, defaults)
|
||||
|
||||
@ -526,6 +544,12 @@ class Cain(Optimizer):
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
beta1, beta2 = group["betas"]
|
||||
lr = group["lr"]
|
||||
target_rms = group["target_rms"]
|
||||
rms_eps = group["rms_eps"]
|
||||
rms_max = group["rms_max"]
|
||||
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
@ -550,14 +574,61 @@ class Cain(Optimizer):
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
if p.numel() > 1:
|
||||
# we keep track of the RMS value of the parameters to
|
||||
# help set the learning rate... this emulates Eve.
|
||||
state["param_rms"] = (p**2).mean().sqrt()
|
||||
# we learn the scale on the parameters in a way
|
||||
# that also emulates Eve.
|
||||
state["scale_exp_avg_sq"] = torch.zeros((), device=p.device,
|
||||
dtype=p.dtype)
|
||||
state["scale_exp_avg"] = torch.zeros((), device=p.device,
|
||||
dtype=p.dtype)
|
||||
|
||||
|
||||
step = state["step"]
|
||||
|
||||
if p.numel() > 1:
|
||||
# This part learns the scale on the parameters.
|
||||
# scale_deriv is the derivative w.r.t. a log-space scale on the parameters, i.e.
|
||||
# if we imagine: p = underlying_param * scale.exp(), this is the derivative
|
||||
# w.r.t. scale (it does not actually matter what value `scale` currently has).
|
||||
scale_deriv = (p * grad).sum()
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"]
|
||||
scale_exp_avg = state["scale_exp_avg"]
|
||||
scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
|
||||
value=1 - beta2)
|
||||
|
||||
scale_exp_avg.mul_(beta1).add_(scale_deriv, alpha=(1-beta1))
|
||||
scale_bias_correction2 = 1 - beta2 ** step
|
||||
|
||||
scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
|
||||
|
||||
scale_alpha = -lr*(scale_bias_correction2 ** 0.5)
|
||||
|
||||
# scale_delta is going to be the change in log-scale, i.e. if
|
||||
# p = underlying_param * scale.exp(),
|
||||
# delta is the change in `scale`.
|
||||
scale_delta = scale_alpha * (scale_exp_avg / scale_denom)
|
||||
|
||||
# the following is equivalent to:
|
||||
# if torch.all(state["param_rms"] > rms_max):
|
||||
# delta = -1.0e-03
|
||||
# which means: if the parameter root-mean-square value goes above the
|
||||
# specified limit, start to decay the parameters (and ignore the computed
|
||||
# delta).
|
||||
scale_delta = torch.where(state["param_rms"] > rms_max,
|
||||
torch.tensor(-1.0e-03, device=scale_delta.device,
|
||||
dtype=scale_delta.dtype),
|
||||
scale_delta)
|
||||
|
||||
p.mul_(1 + scale_delta)
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
# forget bias_correction1. We use normal momentum. exp_avg really
|
||||
# just stores the moving-average gradient step.
|
||||
bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(state["step"])
|
||||
bias_correction2 = 1 - beta2 ** self._step_for_bias_correction(step)
|
||||
|
||||
grad = self._change_coordinates(grad, state, forward=True)
|
||||
|
||||
@ -566,23 +637,22 @@ class Cain(Optimizer):
|
||||
denom = (exp_avg_sq.sqrt()).add_(group["eps"])
|
||||
|
||||
|
||||
step_size = group["lr"] # forget bias_correction1
|
||||
target_rms = group["target_rms"]
|
||||
weight_decay = group["weight_decay"]
|
||||
|
||||
this_delta = grad / denom
|
||||
this_delta = self._change_coordinates(this_delta, state, forward=False)
|
||||
|
||||
# treat/repurpose exp_avg as moving-average of `this_delta`
|
||||
exp_avg.mul_(beta1).add_(this_delta, alpha=-step_size*(1-beta1)*(bias_correction2 ** 0.5))
|
||||
|
||||
alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
|
||||
if p.numel() > 1:
|
||||
# avoid applying this weight-decay on "scaling factors"
|
||||
# (which are scalar).
|
||||
is_above_target_rms = p.norm() > (
|
||||
target_rms * (p.numel() ** 0.5)
|
||||
)
|
||||
p.mul_(1 - (weight_decay * is_above_target_rms))
|
||||
if step % 10 == 0:
|
||||
state["param_rms"].fill_((p**2).mean().sqrt())
|
||||
# imagine param was normalized to rms=target_rms and we stored the
|
||||
# scale as a separate scalar. This is how fast we'd be learning.
|
||||
alpha = (alpha / target_rms) * state["param_rms"].clamp(min=rms_eps)
|
||||
|
||||
# treat/repurpose exp_avg as moving-average of `this_delta`
|
||||
# don't use alpha=alpha as I don't think
|
||||
exp_avg.mul_(beta1).add_(this_delta, alpha=alpha)
|
||||
|
||||
|
||||
p.add_(exp_avg)
|
||||
state["step"] += 1
|
||||
@ -1059,7 +1129,7 @@ def _test_eve_cain():
|
||||
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
||||
|
||||
for iter in [0,1]:
|
||||
for iter in [1,0]:
|
||||
fix_random_seed(42)
|
||||
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
||||
m = torch.nn.Sequential(Linear(E, 200),
|
||||
@ -1071,7 +1141,7 @@ def _test_eve_cain():
|
||||
|
||||
if iter == 0: optim = Eve(m.parameters(), lr=0.003)
|
||||
else: optim = Cain(m.parameters(), lr=0.003)
|
||||
scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
|
||||
scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)
|
||||
|
||||
start = timeit.default_timer()
|
||||
for epoch in range(150):
|
||||
@ -1121,5 +1191,7 @@ def _test_eve_cain():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
_test_eve_cain()
|
||||
#_test_eden()
|
||||
|
Loading…
x
Reference in New Issue
Block a user